diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 4fdcc9ecd20..a68f99425cb 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1791,6 +1791,7 @@ dependencies = [ "eventsource-stream", "futures", "http 1.4.0", + "iana-time-zone", "image", "indexmap 2.13.0", "insta", @@ -2299,6 +2300,7 @@ dependencies = [ "anyhow", "async-trait", "clap", + "codex-protocol", "codex-utils-absolute-path", "libc", "pretty_assertions", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index c72bb907f3c..01836cde05d 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -118,8 +118,8 @@ codex-shell-command = { path = "shell-command" } codex-shell-escalation = { path = "shell-escalation" } codex-skills = { path = "skills" } codex-state = { path = "state" } -codex-test-macros = { path = "test-macros" } codex-stdio-to-uds = { path = "stdio-to-uds" } +codex-test-macros = { path = "test-macros" } codex-tui = { path = "tui" } codex-utils-absolute-path = { path = "utils/absolute-path" } codex-utils-approval-presets = { path = "utils/approval-presets" } @@ -137,8 +137,8 @@ codex-utils-readiness = { path = "utils/readiness" } codex-utils-rustls-provider = { path = "utils/rustls-provider" } codex-utils-sandbox-summary = { path = "utils/sandbox-summary" } codex-utils-sleep-inhibitor = { path = "utils/sleep-inhibitor" } -codex-utils-string = { path = "utils/string" } codex-utils-stream-parser = { path = "utils/stream-parser" } +codex-utils-string = { path = "utils/string" } codex-windows-sandbox = { path = "windows-sandbox-rs" } core_test_support = { path = "core/tests/common" } mcp_test_support = { path = "mcp-server/tests/common" } @@ -165,8 +165,8 @@ clap = "4" clap_complete = "4" color-eyre = "0.6.3" crossbeam-channel = "0.5.15" -csv = "1.3.1" crossterm = "0.28.1" +csv = "1.3.1" ctor = "0.6.3" derive_more = "2" diffy = "0.4.2" @@ -178,14 +178,15 @@ env-flags = "0.1.1" env_logger = "0.11.9" eventsource-stream = "0.2.3" futures = { version = "0.3", default-features 
= false } -globset = "0.4" gethostname = "1.1.0" +globset = "0.4" http = "1.3.1" icu_decimal = "2.1" icu_locale_core = "2.1" icu_provider = { version = "2.1", features = ["sync"] } ignore = "0.4.23" image = { version = "^0.25.9", default-features = false } +iana-time-zone = "0.1.64" include_dir = "0.7.4" indexmap = "2.12.0" insta = "1.46.3" @@ -258,6 +259,7 @@ starlark = "0.13.0" strum = "0.27.2" strum_macros = "0.27.2" supports-color = "3.0.2" +syntect = "5" sys-locale = "0.3.2" tempfile = "3.23.0" test-log = "0.2.19" @@ -282,7 +284,6 @@ tracing-subscriber = "0.3.22" tracing-test = "0.2.5" tree-sitter = "0.25.10" tree-sitter-bash = "0.25" -syntect = "5" ts-rs = "11" tungstenite = { version = "0.27.0", features = ["deflate", "proxy"] } uds_windows = "1.1.0" @@ -352,6 +353,7 @@ ignored = [ [profile.release] lto = "fat" +split-debuginfo = "off" # Because we bundle some of these executables with the TypeScript CLI, we # remove everything to make the binary as small as possible. strip = "symbols" diff --git a/codex-rs/app-server-protocol/schema/json/ClientRequest.json b/codex-rs/app-server-protocol/schema/json/ClientRequest.json index 5ab197c84bc..03e42bed45e 100644 --- a/codex-rs/app-server-protocol/schema/json/ClientRequest.json +++ b/codex-rs/app-server-protocol/schema/json/ClientRequest.json @@ -1340,7 +1340,7 @@ "type": "string" }, "output": { - "type": "string" + "$ref": "#/definitions/FunctionCallOutputPayload" }, "type": { "enum": [ diff --git a/codex-rs/app-server-protocol/schema/json/EventMsg.json b/codex-rs/app-server-protocol/schema/json/EventMsg.json index c7bf9087480..9f442ece9dc 100644 --- a/codex-rs/app-server-protocol/schema/json/EventMsg.json +++ b/codex-rs/app-server-protocol/schema/json/EventMsg.json @@ -4822,7 +4822,7 @@ "type": "string" }, "output": { - "type": "string" + "$ref": "#/definitions/FunctionCallOutputPayload" }, "type": { "enum": [ diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json 
b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json index 62850442fc3..a107340464a 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json @@ -10225,6 +10225,16 @@ }, "Model": { "properties": { + "availabilityNux": { + "anyOf": [ + { + "$ref": "#/definitions/v2/ModelAvailabilityNux" + }, + { + "type": "null" + } + ] + }, "defaultReasoningEffort": { "$ref": "#/definitions/v2/ReasoningEffort" }, @@ -10271,6 +10281,16 @@ "string", "null" ] + }, + "upgradeInfo": { + "anyOf": [ + { + "$ref": "#/definitions/v2/ModelUpgradeInfo" + }, + { + "type": "null" + } + ] } }, "required": [ @@ -10285,6 +10305,17 @@ ], "type": "object" }, + "ModelAvailabilityNux": { + "properties": { + "message": { + "type": "string" + } + }, + "required": [ + "message" + ], + "type": "object" + }, "ModelListParams": { "$schema": "http://json-schema.org/draft-07/schema#", "properties": { @@ -10373,6 +10404,35 @@ "title": "ModelReroutedNotification", "type": "object" }, + "ModelUpgradeInfo": { + "properties": { + "migrationMarkdown": { + "type": [ + "string", + "null" + ] + }, + "model": { + "type": "string" + }, + "modelLink": { + "type": [ + "string", + "null" + ] + }, + "upgradeCopy": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "model" + ], + "type": "object" + }, "NetworkAccess": { "enum": [ "restricted", @@ -11382,7 +11442,7 @@ "type": "string" }, "output": { - "type": "string" + "$ref": "#/definitions/v2/FunctionCallOutputPayload" }, "type": { "enum": [ diff --git a/codex-rs/app-server-protocol/schema/json/v2/ModelListResponse.json b/codex-rs/app-server-protocol/schema/json/v2/ModelListResponse.json index 2afa018dc50..4023f76b51b 100644 --- a/codex-rs/app-server-protocol/schema/json/v2/ModelListResponse.json +++ b/codex-rs/app-server-protocol/schema/json/v2/ModelListResponse.json @@ -22,6 +22,16 @@ }, "Model": { 
"properties": { + "availabilityNux": { + "anyOf": [ + { + "$ref": "#/definitions/ModelAvailabilityNux" + }, + { + "type": "null" + } + ] + }, "defaultReasoningEffort": { "$ref": "#/definitions/ReasoningEffort" }, @@ -68,6 +78,16 @@ "string", "null" ] + }, + "upgradeInfo": { + "anyOf": [ + { + "$ref": "#/definitions/ModelUpgradeInfo" + }, + { + "type": "null" + } + ] } }, "required": [ @@ -82,6 +102,46 @@ ], "type": "object" }, + "ModelAvailabilityNux": { + "properties": { + "message": { + "type": "string" + } + }, + "required": [ + "message" + ], + "type": "object" + }, + "ModelUpgradeInfo": { + "properties": { + "migrationMarkdown": { + "type": [ + "string", + "null" + ] + }, + "model": { + "type": "string" + }, + "modelLink": { + "type": [ + "string", + "null" + ] + }, + "upgradeCopy": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "model" + ], + "type": "object" + }, "ReasoningEffort": { "description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning", "enum": [ diff --git a/codex-rs/app-server-protocol/schema/json/v2/RawResponseItemCompletedNotification.json b/codex-rs/app-server-protocol/schema/json/v2/RawResponseItemCompletedNotification.json index 748eeaab463..4717ff266be 100644 --- a/codex-rs/app-server-protocol/schema/json/v2/RawResponseItemCompletedNotification.json +++ b/codex-rs/app-server-protocol/schema/json/v2/RawResponseItemCompletedNotification.json @@ -565,7 +565,7 @@ "type": "string" }, "output": { - "type": "string" + "$ref": "#/definitions/FunctionCallOutputPayload" }, "type": { "enum": [ diff --git a/codex-rs/app-server-protocol/schema/json/v2/ThreadResumeParams.json b/codex-rs/app-server-protocol/schema/json/v2/ThreadResumeParams.json index 29d6fbc6d74..ef7607d3ca2 100644 --- a/codex-rs/app-server-protocol/schema/json/v2/ThreadResumeParams.json +++ b/codex-rs/app-server-protocol/schema/json/v2/ThreadResumeParams.json @@ -615,7 +615,7 @@ "type": "string" }, "output": { - 
"type": "string" + "$ref": "#/definitions/FunctionCallOutputPayload" }, "type": { "enum": [ diff --git a/codex-rs/app-server-protocol/schema/typescript/ResponseItem.ts b/codex-rs/app-server-protocol/schema/typescript/ResponseItem.ts index 611c7fb22db..dd7621f01d6 100644 --- a/codex-rs/app-server-protocol/schema/typescript/ResponseItem.ts +++ b/codex-rs/app-server-protocol/schema/typescript/ResponseItem.ts @@ -15,4 +15,4 @@ export type ResponseItem = { "type": "message", role: string, content: Array, defaultReasoningEffort: ReasoningEffort, inputModalities: Array, supportsPersonality: boolean, isDefault: boolean, }; +export type Model = { id: string, model: string, upgrade: string | null, upgradeInfo: ModelUpgradeInfo | null, availabilityNux: ModelAvailabilityNux | null, displayName: string, description: string, hidden: boolean, supportedReasoningEfforts: Array, defaultReasoningEffort: ReasoningEffort, inputModalities: Array, supportsPersonality: boolean, isDefault: boolean, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/ModelAvailabilityNux.ts b/codex-rs/app-server-protocol/schema/typescript/v2/ModelAvailabilityNux.ts new file mode 100644 index 00000000000..7254aaec9b7 --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/v2/ModelAvailabilityNux.ts @@ -0,0 +1,5 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type ModelAvailabilityNux = { message: string, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/ModelUpgradeInfo.ts b/codex-rs/app-server-protocol/schema/typescript/v2/ModelUpgradeInfo.ts new file mode 100644 index 00000000000..82d73e9d062 --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/v2/ModelUpgradeInfo.ts @@ -0,0 +1,5 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. 
+ +export type ModelUpgradeInfo = { model: string, upgradeCopy: string | null, modelLink: string | null, migrationMarkdown: string | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/index.ts b/codex-rs/app-server-protocol/schema/typescript/v2/index.ts index 4c638c7d30d..4a881427d90 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/index.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/index.ts @@ -107,10 +107,12 @@ export type { McpToolCallResult } from "./McpToolCallResult"; export type { McpToolCallStatus } from "./McpToolCallStatus"; export type { MergeStrategy } from "./MergeStrategy"; export type { Model } from "./Model"; +export type { ModelAvailabilityNux } from "./ModelAvailabilityNux"; export type { ModelListParams } from "./ModelListParams"; export type { ModelListResponse } from "./ModelListResponse"; export type { ModelRerouteReason } from "./ModelRerouteReason"; export type { ModelReroutedNotification } from "./ModelReroutedNotification"; +export type { ModelUpgradeInfo } from "./ModelUpgradeInfo"; export type { NetworkAccess } from "./NetworkAccess"; export type { NetworkApprovalContext } from "./NetworkApprovalContext"; export type { NetworkApprovalProtocol } from "./NetworkApprovalProtocol"; diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index f7c4eec7a7e..32bc120f6dd 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -31,6 +31,7 @@ use codex_protocol::models::MessagePhase; use codex_protocol::models::PermissionProfile as CorePermissionProfile; use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelAvailabilityNux as CoreModelAvailabilityNux; use codex_protocol::openai_models::ReasoningEffort; use codex_protocol::openai_models::default_input_modalities; use 
codex_protocol::parse_command::ParsedCommand as CoreParsedCommand; @@ -1389,6 +1390,21 @@ pub struct ModelListParams { pub include_hidden: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct ModelAvailabilityNux { + pub message: String, +} + +impl From for ModelAvailabilityNux { + fn from(value: CoreModelAvailabilityNux) -> Self { + Self { + message: value.message, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -1396,6 +1412,8 @@ pub struct Model { pub id: String, pub model: String, pub upgrade: Option, + pub upgrade_info: Option, + pub availability_nux: Option, pub display_name: String, pub description: String, pub hidden: bool, @@ -1409,6 +1427,16 @@ pub struct Model { pub is_default: bool, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct ModelUpgradeInfo { + pub model: String, + pub upgrade_copy: Option, + pub model_link: Option, + pub migration_markdown: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 20b2dcf2c7b..60cf6d6107c 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -142,7 +142,7 @@ Example with notification opt-out: - `thread/realtime/stop` — stop the active realtime session for the thread (experimental); returns `{}`. - `review/start` — kick off Codex’s automated reviewer for a thread; responds like `turn/start` and emits `item/started`/`item/completed` notifications with `enteredReviewMode` and `exitedReviewMode` items, plus a final assistant `agentMessage` containing the review. 
- `command/exec` — run a single command under the server sandbox without starting a thread/turn (handy for utilities and validation). -- `model/list` — list available models (set `includeHidden: true` to include entries with `hidden: true`), with reasoning effort options and optional `upgrade` model ids. +- `model/list` — list available models (set `includeHidden: true` to include entries with `hidden: true`), with reasoning effort options, optional legacy `upgrade` model ids, optional `upgradeInfo` metadata (`model`, `upgradeCopy`, `modelLink`, `migrationMarkdown`), and optional `availabilityNux` metadata. - `experimentalFeature/list` — list feature flags with stage metadata (`beta`, `underDevelopment`, `stable`, etc.), enabled/default-enabled state, and cursor pagination. For non-beta flags, `displayName`/`description`/`announcement` are `null`. - `collaborationMode/list` — list available collaboration mode presets (experimental, no pagination). This response omits built-in developer instructions; clients should either pass `settings.developer_instructions: null` when setting a mode to use Codex's built-in instructions, or provide their own instructions explicitly. - `skills/list` — list skills for one or more `cwd` values (optional `forceReload`). 
diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 0d3b98f03cd..1f6e1ac0f1f 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -560,7 +560,12 @@ impl CodexMessageProcessor { Ok((review_request, hint)) } - pub async fn process_request(&mut self, connection_id: ConnectionId, request: ClientRequest) { + pub async fn process_request( + &mut self, + connection_id: ConnectionId, + request: ClientRequest, + app_server_client_name: Option, + ) { let to_connection_request_id = |request_id| ConnectionRequestId { connection_id, request_id, @@ -647,8 +652,12 @@ impl CodexMessageProcessor { .await; } ClientRequest::TurnStart { request_id, params } => { - self.turn_start(to_connection_request_id(request_id), params) - .await; + self.turn_start( + to_connection_request_id(request_id), + params, + app_server_client_name.clone(), + ) + .await; } ClientRequest::TurnSteer { request_id, params } => { self.turn_steer(to_connection_request_id(request_id), params) @@ -767,12 +776,20 @@ impl CodexMessageProcessor { .await; } ClientRequest::SendUserMessage { request_id, params } => { - self.send_user_message(to_connection_request_id(request_id), params) - .await; + self.send_user_message( + to_connection_request_id(request_id), + params, + app_server_client_name.clone(), + ) + .await; } ClientRequest::SendUserTurn { request_id, params } => { - self.send_user_turn(to_connection_request_id(request_id), params) - .await; + self.send_user_turn( + to_connection_request_id(request_id), + params, + app_server_client_name.clone(), + ) + .await; } ClientRequest::InterruptConversation { request_id, params } => { self.interrupt_conversation(to_connection_request_id(request_id), params) @@ -4152,6 +4169,7 @@ impl CodexMessageProcessor { http_headers, env_http_headers, scopes.as_deref().unwrap_or_default(), + server.oauth_resource.as_deref(), timeout_secs, 
config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), @@ -5062,6 +5080,7 @@ impl CodexMessageProcessor { &self, request_id: ConnectionRequestId, params: SendUserMessageParams, + app_server_client_name: Option, ) { let SendUserMessageParams { conversation_id, @@ -5080,6 +5099,12 @@ impl CodexMessageProcessor { self.outgoing.send_error(request_id, error).await; return; }; + if let Err(error) = + Self::set_app_server_client_name(conversation.as_ref(), app_server_client_name).await + { + self.outgoing.send_error(request_id, error).await; + return; + } let mapped_items: Vec = items .into_iter() @@ -5110,7 +5135,12 @@ impl CodexMessageProcessor { .await; } - async fn send_user_turn(&self, request_id: ConnectionRequestId, params: SendUserTurnParams) { + async fn send_user_turn( + &self, + request_id: ConnectionRequestId, + params: SendUserTurnParams, + app_server_client_name: Option, + ) { let SendUserTurnParams { conversation_id, items, @@ -5136,6 +5166,12 @@ impl CodexMessageProcessor { self.outgoing.send_error(request_id, error).await; return; }; + if let Err(error) = + Self::set_app_server_client_name(conversation.as_ref(), app_server_client_name).await + { + self.outgoing.send_error(request_id, error).await; + return; + } let mapped_items: Vec = items .into_iter() @@ -5160,7 +5196,7 @@ impl CodexMessageProcessor { sandbox_policy, model, effort, - summary, + summary: Some(summary), final_output_json_schema: output_schema, collaboration_mode: None, personality: None, @@ -5249,6 +5285,7 @@ impl CodexMessageProcessor { connectors::list_cached_accessible_connectors_from_mcp_tools(&config), connectors::list_cached_all_connectors(&config) ); + let cached_all_connectors = all_connectors.clone(); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -5275,6 +5312,19 @@ impl CodexMessageProcessor { let app_list_deadline = tokio::time::Instant::now() + APP_LIST_LOAD_TIMEOUT; let mut accessible_loaded = false; let mut all_loaded = false; + let mut 
last_notified_apps = None; + + if accessible_connectors.is_some() || all_connectors.is_some() { + let merged = connectors::with_app_enabled_state( + Self::merge_loaded_apps( + all_connectors.as_deref(), + accessible_connectors.as_deref(), + ), + &config, + ); + Self::send_app_list_updated_notification(&outgoing, merged.clone()).await; + last_notified_apps = Some(merged); + } loop { let result = match tokio::time::timeout_at(app_list_deadline, rx.recv()).await { @@ -5331,14 +5381,30 @@ impl CodexMessageProcessor { } } + let showing_interim_force_refetch = force_refetch && !(accessible_loaded && all_loaded); + let all_connectors_for_update = + if showing_interim_force_refetch && cached_all_connectors.is_some() { + cached_all_connectors.as_deref() + } else { + all_connectors.as_deref() + }; + let accessible_connectors_for_update = + if showing_interim_force_refetch && !accessible_loaded { + None + } else { + accessible_connectors.as_deref() + }; let merged = connectors::with_app_enabled_state( Self::merge_loaded_apps( - all_connectors.as_deref(), - accessible_connectors.as_deref(), + all_connectors_for_update, + accessible_connectors_for_update, ), &config, ); - Self::send_app_list_updated_notification(&outgoing, merged.clone()).await; + if last_notified_apps.as_ref() != Some(&merged) { + Self::send_app_list_updated_notification(&outgoing, merged.clone()).await; + last_notified_apps = Some(merged.clone()); + } if accessible_loaded && all_loaded { match Self::paginate_apps(merged.as_slice(), start, limit) { @@ -5607,7 +5673,12 @@ impl CodexMessageProcessor { let _ = conversation.submit(Op::Interrupt).await; } - async fn turn_start(&self, request_id: ConnectionRequestId, params: TurnStartParams) { + async fn turn_start( + &self, + request_id: ConnectionRequestId, + params: TurnStartParams, + app_server_client_name: Option, + ) { if let Err(error) = Self::validate_v2_input_limit(¶ms.input) { self.outgoing.send_error(request_id, error).await; return; @@ -5619,6 +5690,12 
@@ impl CodexMessageProcessor { return; } }; + if let Err(error) = + Self::set_app_server_client_name(thread.as_ref(), app_server_client_name).await + { + self.outgoing.send_error(request_id, error).await; + return; + } let collaboration_modes_config = CollaborationModesConfig { default_mode_request_user_input: thread.enabled(Feature::DefaultModeRequestUserInput), @@ -5700,6 +5777,20 @@ impl CodexMessageProcessor { } } + async fn set_app_server_client_name( + thread: &CodexThread, + app_server_client_name: Option, + ) -> Result<(), JSONRPCErrorError> { + thread + .set_app_server_client_name(app_server_client_name) + .await + .map_err(|err| JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to set app server client name: {err}"), + data: None, + }) + } + async fn turn_steer(&self, request_id: ConnectionRequestId, params: TurnSteerParams) { let (_, thread) = match self.load_thread(¶ms.thread_id).await { Ok(v) => v, diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 9f1b182cfa0..79f845ad4b3 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -140,6 +140,7 @@ pub(crate) struct ConnectionSessionState { pub(crate) initialized: bool, pub(crate) experimental_api_enabled: bool, pub(crate) opted_out_notification_methods: HashSet, + pub(crate) app_server_client_name: Option, } pub(crate) struct MessageProcessorArgs { @@ -329,6 +330,7 @@ impl MessageProcessor { if let Ok(mut suffix) = USER_AGENT_SUFFIX.lock() { *suffix = Some(user_agent_suffix); } + session.app_server_client_name = Some(name.clone()); let user_agent = get_codex_user_agent(); let response = InitializeResponse { user_agent }; @@ -430,7 +432,7 @@ impl MessageProcessor { // inline the full `CodexMessageProcessor::process_request` future, which // can otherwise push worker-thread stack usage over the edge. 
self.codex_message_processor - .process_request(connection_id, other) + .process_request(connection_id, other, session.app_server_client_name.clone()) .boxed() .await; } diff --git a/codex-rs/app-server/src/models.rs b/codex-rs/app-server/src/models.rs index 1594c66229f..6dbe77455f9 100644 --- a/codex-rs/app-server/src/models.rs +++ b/codex-rs/app-server/src/models.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use codex_app_server_protocol::Model; +use codex_app_server_protocol::ModelUpgradeInfo; use codex_app_server_protocol::ReasoningEffortOption; use codex_core::ThreadManager; use codex_core::models_manager::manager::RefreshStrategy; @@ -24,7 +25,14 @@ fn model_from_preset(preset: ModelPreset) -> Model { Model { id: preset.id.to_string(), model: preset.model.to_string(), - upgrade: preset.upgrade.map(|upgrade| upgrade.id), + upgrade: preset.upgrade.as_ref().map(|upgrade| upgrade.id.clone()), + upgrade_info: preset.upgrade.as_ref().map(|upgrade| ModelUpgradeInfo { + model: upgrade.id.clone(), + upgrade_copy: upgrade.upgrade_copy.clone(), + model_link: upgrade.model_link.clone(), + migration_markdown: upgrade.migration_markdown.clone(), + }), + availability_nux: preset.availability_nux.map(Into::into), display_name: preset.display_name.to_string(), description: preset.description.to_string(), hidden: !preset.show_in_picker, diff --git a/codex-rs/app-server/tests/common/models_cache.rs b/codex-rs/app-server/tests/common/models_cache.rs index 218a1d0e406..0de8fda5f2d 100644 --- a/codex-rs/app-server/tests/common/models_cache.rs +++ b/codex-rs/app-server/tests/common/models_cache.rs @@ -1,6 +1,7 @@ use chrono::DateTime; use chrono::Utc; use codex_core::test_support::all_model_presets; +use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelPreset; @@ -30,8 +31,10 @@ fn preset_to_info(preset: &ModelPreset, priority: i32) -> ModelInfo { 
base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, diff --git a/codex-rs/app-server/tests/suite/send_message.rs b/codex-rs/app-server/tests/suite/send_message.rs index e5c023d9d17..e755a11b772 100644 --- a/codex-rs/app-server/tests/suite/send_message.rs +++ b/codex-rs/app-server/tests/suite/send_message.rs @@ -620,6 +620,8 @@ fn append_rollout_turn_context(path: &Path, timestamp: &str, model: &str) -> std item: RolloutItem::TurnContext(TurnContextItem { turn_id: None, cwd: PathBuf::from("/"), + current_date: None, + timezone: None, approval_policy: AskForApproval::Never, sandbox_policy: SandboxPolicy::DangerFullAccess, network: None, diff --git a/codex-rs/app-server/tests/suite/v2/app_list.rs b/codex-rs/app-server/tests/suite/v2/app_list.rs index a9bc14e9328..6a3243ad560 100644 --- a/codex-rs/app-server/tests/suite/v2/app_list.rs +++ b/codex-rs/app-server/tests/suite/v2/app_list.rs @@ -994,6 +994,41 @@ async fn list_apps_force_refetch_patches_updates_from_cached_snapshots() -> Resu let first_update = read_app_list_updated_notification(&mut mcp).await?; assert_eq!( first_update.data, + vec![ + AppInfo { + id: "beta".to_string(), + name: "Beta App".to_string(), + description: Some("Beta v1".to_string()), + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: Some("https://chatgpt.com/apps/beta-app/beta".to_string()), + is_accessible: true, + is_enabled: true, + }, + AppInfo { + id: "alpha".to_string(), + name: "Alpha".to_string(), + description: Some("Alpha v1".to_string()), + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + 
app_metadata: None, + labels: None, + install_url: Some("https://chatgpt.com/apps/alpha/alpha".to_string()), + is_accessible: false, + is_enabled: true, + }, + ] + ); + + let second_update = read_app_list_updated_notification(&mut mcp).await?; + assert_eq!( + second_update.data, vec![ AppInfo { id: "alpha".to_string(), @@ -1040,8 +1075,8 @@ async fn list_apps_force_refetch_patches_updates_from_cached_snapshots() -> Resu is_accessible: false, is_enabled: true, }]; - let second_update = read_app_list_updated_notification(&mut mcp).await?; - assert_eq!(second_update.data, expected_final); + let third_update = read_app_list_updated_notification(&mut mcp).await?; + assert_eq!(third_update.data, expected_final); let refetch_response: JSONRPCResponse = timeout( DEFAULT_TIMEOUT, diff --git a/codex-rs/app-server/tests/suite/v2/initialize.rs b/codex-rs/app-server/tests/suite/v2/initialize.rs index 2edc83a7f49..52da448a4c8 100644 --- a/codex-rs/app-server/tests/suite/v2/initialize.rs +++ b/codex-rs/app-server/tests/suite/v2/initialize.rs @@ -1,16 +1,24 @@ use anyhow::Result; use app_test_support::McpProcess; +use app_test_support::create_final_assistant_message_sse_response; use app_test_support::create_mock_responses_server_sequence_unchecked; use app_test_support::to_response; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::InitializeCapabilities; use codex_app_server_protocol::InitializeResponse; use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStartParams; +use codex_app_server_protocol::TurnStartResponse; +use codex_app_server_protocol::UserInput as V2UserInput; +use core_test_support::fs_wait; use pretty_assertions::assert_eq; +use serde_json::Value; use std::path::Path; +use std::time::Duration; use tempfile::TempDir; 
use tokio::time::timeout; @@ -178,11 +186,100 @@ async fn initialize_opt_out_notification_methods_filters_notifications() -> Resu Ok(()) } +#[tokio::test] +async fn turn_start_notify_payload_includes_initialize_client_name() -> Result<()> { + let responses = vec![create_final_assistant_message_sse_response("Done")?]; + let server = create_mock_responses_server_sequence_unchecked(responses).await; + let codex_home = TempDir::new()?; + let notify_script = codex_home.path().join("notify.py"); + std::fs::write( + ¬ify_script, + r#"from pathlib import Path +import sys + +Path(__file__).with_name("notify.json").write_text(sys.argv[-1], encoding="utf-8") +"#, + )?; + let notify_file = codex_home.path().join("notify.json"); + let notify_script = notify_script + .to_str() + .expect("notify script path should be valid UTF-8"); + create_config_toml_with_extra( + codex_home.path(), + &server.uri(), + "never", + &format!( + "notify = [\"python3\", {}]", + toml_basic_string(notify_script) + ), + )?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout( + DEFAULT_READ_TIMEOUT, + mcp.initialize_with_client_info(ClientInfo { + name: "xcode".to_string(), + title: Some("Xcode".to_string()), + version: "1.0.0".to_string(), + }), + ) + .await??; + + let thread_req = mcp + .send_thread_start_request(ThreadStartParams::default()) + .await?; + let thread_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), + ) + .await??; + let ThreadStartResponse { thread, .. 
} = to_response(thread_resp)?; + + let turn_req = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id, + input: vec![V2UserInput::Text { + text: "Hello".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }) + .await?; + let turn_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_req)), + ) + .await??; + let _: TurnStartResponse = to_response(turn_resp)?; + + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + + fs_wait::wait_for_path_exists(¬ify_file, Duration::from_secs(5)).await?; + let payload_raw = tokio::fs::read_to_string(¬ify_file).await?; + let payload: Value = serde_json::from_str(&payload_raw)?; + assert_eq!(payload["client"], "xcode"); + + Ok(()) +} + // Helper to create a config.toml pointing at the mock model server. fn create_config_toml( codex_home: &Path, server_uri: &str, approval_policy: &str, +) -> std::io::Result<()> { + create_config_toml_with_extra(codex_home, server_uri, approval_policy, "") +} + +fn create_config_toml_with_extra( + codex_home: &Path, + server_uri: &str, + approval_policy: &str, + extra: &str, ) -> std::io::Result<()> { let config_toml = codex_home.join("config.toml"); std::fs::write( @@ -195,6 +292,8 @@ sandbox_mode = "read-only" model_provider = "mock_provider" +{extra} + [model_providers.mock_provider] name = "Mock provider for test" base_url = "{server_uri}/v1" @@ -205,3 +304,7 @@ stream_max_retries = 0 ), ) } + +fn toml_basic_string(value: &str) -> String { + format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\"")) +} diff --git a/codex-rs/app-server/tests/suite/v2/model_list.rs b/codex-rs/app-server/tests/suite/v2/model_list.rs index a71a8a3377e..eba8905a771 100644 --- a/codex-rs/app-server/tests/suite/v2/model_list.rs +++ b/codex-rs/app-server/tests/suite/v2/model_list.rs @@ -9,6 +9,7 @@ use codex_app_server_protocol::JSONRPCResponse; 
use codex_app_server_protocol::Model; use codex_app_server_protocol::ModelListParams; use codex_app_server_protocol::ModelListResponse; +use codex_app_server_protocol::ModelUpgradeInfo; use codex_app_server_protocol::ReasoningEffortOption; use codex_app_server_protocol::RequestId; use codex_protocol::openai_models::ModelPreset; @@ -24,6 +25,13 @@ fn model_from_preset(preset: &ModelPreset) -> Model { id: preset.id.clone(), model: preset.model.clone(), upgrade: preset.upgrade.as_ref().map(|upgrade| upgrade.id.clone()), + upgrade_info: preset.upgrade.as_ref().map(|upgrade| ModelUpgradeInfo { + model: upgrade.id.clone(), + upgrade_copy: upgrade.upgrade_copy.clone(), + model_link: upgrade.model_link.clone(), + migration_markdown: upgrade.migration_markdown.clone(), + }), + availability_nux: preset.availability_nux.clone().map(Into::into), display_name: preset.display_name.clone(), description: preset.description.clone(), hidden: !preset.show_in_picker, diff --git a/codex-rs/chatgpt/src/connectors.rs b/codex-rs/chatgpt/src/connectors.rs index a3470ff1b9d..620a93872ec 100644 --- a/codex-rs/chatgpt/src/connectors.rs +++ b/codex-rs/chatgpt/src/connectors.rs @@ -298,12 +298,7 @@ fn merge_directory_app(existing: &mut DirectoryApp, incoming: DirectoryApp) { .as_deref() .map(|value| !value.trim().is_empty()) .unwrap_or(false); - let existing_description_present = existing - .description - .as_deref() - .map(|value| !value.trim().is_empty()) - .unwrap_or(false); - if !existing_description_present && incoming_description_present { + if incoming_description_present { existing.description = description; } diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 42ded11c06a..c271fe0b9eb 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -808,6 +808,7 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { stage_width = stage_width.max(stage.len()); rows.push((name, stage, enabled)); } + rows.sort_unstable_by_key(|(name, _, _)| 
*name); for (name, stage, enabled) in rows { println!("{name: Re enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }; servers.insert(name.clone(), new_entry); @@ -272,6 +273,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re oauth_config.http_headers, oauth_config.env_http_headers, &Vec::new(), + None, config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), ) @@ -356,6 +358,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) http_headers, env_http_headers, &scopes, + server.oauth_resource.as_deref(), config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), ) diff --git a/codex-rs/cli/tests/features.rs b/codex-rs/cli/tests/features.rs index 6b2ed72f8af..17a7eff679c 100644 --- a/codex-rs/cli/tests/features.rs +++ b/codex-rs/cli/tests/features.rs @@ -2,6 +2,7 @@ use std::path::Path; use anyhow::Result; use predicates::str::contains; +use pretty_assertions::assert_eq; use tempfile::TempDir; fn codex_command(codex_home: &Path) -> Result { @@ -58,3 +59,33 @@ async fn features_enable_under_development_feature_prints_warning() -> Result<() Ok(()) } + +#[tokio::test] +async fn features_list_is_sorted_alphabetically_by_feature_name() -> Result<()> { + let codex_home = TempDir::new()?; + + let mut cmd = codex_command(codex_home.path())?; + let output = cmd + .args(["features", "list"]) + .assert() + .success() + .get_output() + .stdout + .clone(); + let stdout = String::from_utf8(output)?; + + let actual_names = stdout + .lines() + .map(|line| { + line.split_once(" ") + .map(|(name, _)| name.trim_end().to_string()) + .expect("feature list output should contain aligned columns") + }) + .collect::>(); + let mut expected_names = actual_names.clone(); + expected_names.sort(); + + assert_eq!(actual_names, expected_names); + + Ok(()) +} diff --git a/codex-rs/cloud-requirements/src/lib.rs b/codex-rs/cloud-requirements/src/lib.rs index 
6f6bf3b6dea..d89ff53524f 100644 --- a/codex-rs/cloud-requirements/src/lib.rs +++ b/codex-rs/cloud-requirements/src/lib.rs @@ -28,17 +28,21 @@ use serde::Serialize; use sha2::Sha256; use std::path::PathBuf; use std::sync::Arc; +use std::sync::Mutex; +use std::sync::OnceLock; use std::time::Duration; use std::time::Instant; use thiserror::Error; use tokio::fs; +use tokio::task::JoinHandle; use tokio::time::sleep; use tokio::time::timeout; const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15); const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5; const CLOUD_REQUIREMENTS_CACHE_FILENAME: &str = "cloud-requirements-cache.json"; -const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(60 * 60); +const CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60); +const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(30 * 60); const CLOUD_REQUIREMENTS_CACHE_WRITE_HMAC_KEY: &[u8] = b"codex-cloud-requirements-cache-v3-064f8542-75b4-494c-a294-97d3ce597271"; const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] = @@ -46,6 +50,11 @@ const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] = type HmacSha256 = Hmac; +fn refresher_task_slot() -> &'static Mutex>> { + static REFRESHER_TASK: OnceLock>>> = OnceLock::new(); + REFRESHER_TASK.get_or_init(|| Mutex::new(None)) +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum FetchCloudRequirementsStatus { BackendClientInit, @@ -188,6 +197,7 @@ impl RequirementsFetcher for BackendRequirementsFetcher { } } +#[derive(Clone)] struct CloudRequirementsService { auth_manager: Arc, fetcher: Arc, @@ -325,6 +335,54 @@ impl CloudRequirementsService { None } + async fn refresh_cache_in_background(&self) { + loop { + sleep(CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL).await; + match timeout(self.timeout, self.refresh_cache()).await { + Ok(true) => {} + Ok(false) => break, + Err(_) => { + tracing::warn!( + "Timed out refreshing cloud requirements cache from remote; keeping existing cache" + ); 
+ } + } + } + } + + async fn refresh_cache(&self) -> bool { + let Some(auth) = self.auth_manager.auth().await else { + return false; + }; + if !auth.is_chatgpt_auth() + || !matches!( + auth.account_plan_type(), + Some(PlanType::Business | PlanType::Enterprise) + ) + { + return false; + } + + let token_data = auth.get_token_data().ok(); + let chatgpt_user_id = token_data + .as_ref() + .and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref()); + let account_id = auth.get_account_id(); + let account_id = account_id.as_deref(); + + if self + .fetch_with_retries(&auth, chatgpt_user_id, account_id) + .await + .is_none() + { + tracing::warn!( + path = %self.cache_path.display(), + "Failed to refresh cloud requirements cache from remote" + ); + } + true + } + async fn load_cache( &self, chatgpt_user_id: Option<&str>, @@ -452,7 +510,17 @@ pub fn cloud_requirements_loader( codex_home, CLOUD_REQUIREMENTS_TIMEOUT, ); + let refresh_service = service.clone(); let task = tokio::spawn(async move { service.fetch_with_timeout().await }); + let refresh_task = + tokio::spawn(async move { refresh_service.refresh_cache_in_background().await }); + let mut refresher_guard = refresher_task_slot().lock().unwrap_or_else(|err| { + tracing::warn!("cloud requirements refresher task slot was poisoned"); + err.into_inner() + }); + if let Some(existing_task) = refresher_guard.replace(refresh_task) { + existing_task.abort(); + } CloudRequirementsLoader::new(async move { task.await .inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed")) @@ -1052,7 +1120,11 @@ mod tests { let cache_file: CloudRequirementsCacheFile = serde_json::from_str(&std::fs::read_to_string(path).expect("read cache")) .expect("parse cache"); - assert!(cache_file.signed_payload.expires_at > Utc::now()); + assert!( + cache_file.signed_payload.expires_at + <= cache_file.signed_payload.cached_at + ChronoDuration::minutes(30) + ); + assert!(cache_file.signed_payload.expires_at > 
cache_file.signed_payload.cached_at); assert!(cache_file.signed_payload.cached_at <= Utc::now()); assert_eq!( cache_file.signed_payload.chatgpt_user_id, @@ -1130,4 +1202,57 @@ mod tests { CLOUD_REQUIREMENTS_MAX_ATTEMPTS ); } + + #[tokio::test] + async fn refresh_from_remote_updates_cached_cloud_requirements() { + let codex_home = tempdir().expect("tempdir"); + let fetcher = Arc::new(SequenceFetcher::new(vec![ + Ok(Some("allowed_approval_policies = [\"never\"]".to_string())), + Ok(Some( + "allowed_approval_policies = [\"on-request\"]".to_string(), + )), + ])); + let service = CloudRequirementsService::new( + auth_manager_with_plan("business"), + fetcher, + codex_home.path().to_path_buf(), + CLOUD_REQUIREMENTS_TIMEOUT, + ); + + assert_eq!( + service.fetch().await, + Some(ConfigRequirementsToml { + allowed_approval_policies: Some(vec![AskForApproval::Never]), + allowed_sandbox_modes: None, + allowed_web_search_modes: None, + mcp_servers: None, + rules: None, + enforce_residency: None, + network: None, + }) + ); + + service.refresh_cache().await; + + let path = codex_home.path().join(CLOUD_REQUIREMENTS_CACHE_FILENAME); + let cache_file: CloudRequirementsCacheFile = + serde_json::from_str(&std::fs::read_to_string(path).expect("read cache")) + .expect("parse cache"); + assert_eq!( + cache_file + .signed_payload + .contents + .as_deref() + .and_then(|contents| parse_cloud_requirements(contents).ok().flatten()), + Some(ConfigRequirementsToml { + allowed_approval_policies: Some(vec![AskForApproval::OnRequest]), + allowed_sandbox_modes: None, + allowed_web_search_modes: None, + mcp_servers: None, + rules: None, + enforce_residency: None, + network: None, + }) + ); + } } diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs index 48cc1d7002f..97d0b9e320f 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs +++ 
b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs @@ -25,6 +25,8 @@ use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Error as WsError; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tracing::debug; +use tracing::error; use tracing::info; use tracing::trace; use tungstenite::protocol::WebSocketConfig; @@ -62,15 +64,23 @@ impl WsStream { }; match command { WsCommand::Send { message, tx_result } => { + debug!("realtime websocket sending message"); let result = inner.send(message).await; let should_break = result.is_err(); + if let Err(err) = &result { + error!("realtime websocket send failed: {err}"); + } let _ = tx_result.send(result); if should_break { break; } } WsCommand::Close { tx_result } => { + info!("realtime websocket sending close"); let result = inner.close(None).await; + if let Err(err) = &result { + error!("realtime websocket close failed: {err}"); + } let _ = tx_result.send(result); break; } @@ -82,7 +92,9 @@ impl WsStream { }; match message { Ok(Message::Ping(payload)) => { + trace!(payload_len = payload.len(), "realtime websocket received ping"); if let Err(err) = inner.send(Message::Pong(payload)).await { + error!("realtime websocket failed to send pong: {err}"); let _ = tx_message.send(Err(err)); break; } @@ -93,6 +105,24 @@ impl WsStream { | Message::Close(_) | Message::Frame(_))) => { let is_close = matches!(message, Message::Close(_)); + match &message { + Message::Text(_) => trace!("realtime websocket received text frame"), + Message::Binary(binary) => { + error!( + payload_len = binary.len(), + "realtime websocket received unexpected binary frame" + ); + } + Message::Close(frame) => info!( + "realtime websocket received close frame: code={:?} reason={:?}", + frame.as_ref().map(|frame| frame.code), + frame.as_ref().map(|frame| frame.reason.as_str()) + ), + Message::Frame(_) => { + trace!("realtime websocket received raw frame"); + } + 
Message::Ping(_) | Message::Pong(_) => {} + } if tx_message.send(Ok(message)).is_err() { break; } @@ -101,6 +131,7 @@ impl WsStream { } } Err(err) => { + error!("realtime websocket receive failed: {err}"); let _ = tx_message.send(Err(err)); break; } @@ -108,6 +139,7 @@ impl WsStream { } } } + info!("realtime websocket pump exiting"); }); ( @@ -298,7 +330,7 @@ impl RealtimeWebsocketWriter { async fn send_json(&self, message: RealtimeOutboundMessage) -> Result<(), ApiError> { let payload = serde_json::to_string(&message) .map_err(|err| ApiError::Stream(format!("failed to encode realtime request: {err}")))?; - trace!("realtime websocket request: {payload}"); + debug!(?message, "realtime websocket request"); if self.is_closed.load(Ordering::SeqCst) { return Err(ApiError::Stream( @@ -325,12 +357,14 @@ impl RealtimeWebsocketEvents { Some(Ok(msg)) => msg, Some(Err(err)) => { self.is_closed.store(true, Ordering::SeqCst); + error!("realtime websocket read failed: {err}"); return Err(ApiError::Stream(format!( "failed to read websocket message: {err}" ))); } None => { self.is_closed.store(true, Ordering::SeqCst); + info!("realtime websocket event stream ended"); return Ok(None); } }; @@ -338,11 +372,18 @@ impl RealtimeWebsocketEvents { match msg { Message::Text(text) => { if let Some(event) = parse_realtime_event(&text) { + debug!(?event, "realtime websocket parsed event"); return Ok(Some(event)); } + debug!("realtime websocket ignored unsupported text frame"); } - Message::Close(_) => { + Message::Close(frame) => { self.is_closed.store(true, Ordering::SeqCst); + info!( + "realtime websocket closed: code={:?} reason={:?}", + frame.as_ref().map(|frame| frame.code), + frame.as_ref().map(|frame| frame.reason.as_str()) + ); return Ok(None); } Message::Binary(_) => { @@ -383,15 +424,24 @@ impl RealtimeWebsocketClient { request.headers_mut().extend(headers); info!("connecting realtime websocket: {ws_url}"); - let (stream, _) = + let (stream, response) = 
tokio_tungstenite::connect_async_with_config(request, Some(websocket_config()), false) .await .map_err(|err| { ApiError::Stream(format!("failed to connect realtime websocket: {err}")) })?; + info!( + ws_url = %ws_url, + status = %response.status(), + "realtime websocket connected" + ); let (stream, rx_message) = WsStream::new(stream); let connection = RealtimeWebsocketConnection::new(stream, rx_message); + debug!( + conversation_id = config.session_id.as_deref().unwrap_or(""), + "realtime websocket sending session.create" + ); connection .send_session_create(config.prompt, config.session_id) .await?; diff --git a/codex-rs/codex-api/tests/models_integration.rs b/codex-rs/codex-api/tests/models_integration.rs index 9f7e3806650..2b61f0de60a 100644 --- a/codex-rs/codex-api/tests/models_integration.rs +++ b/codex-rs/codex-api/tests/models_integration.rs @@ -3,6 +3,7 @@ use codex_api::ModelsClient; use codex_api::provider::Provider; use codex_api::provider::RetryConfig; use codex_client::ReqwestTransport; +use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelVisibility; @@ -78,8 +79,10 @@ async fn models_client_hits_models_endpoint() { base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 47731291a42..87f18eb8000 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -61,6 +61,7 @@ env-flags = { workspace = true } eventsource-stream = { workspace = true } futures = { workspace = true } http = { workspace = true } +iana-time-zone = { 
workspace = true } indexmap = { workspace = true } keyring = { workspace = true, features = ["crypto-rust"] } libc = { workspace = true } diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index e1f0fa5147d..d207d5bd8e9 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -630,6 +630,11 @@ "minimum": 0.0, "type": "integer" }, + "max_unused_days": { + "description": "Maximum number of days since a memory was last used before it becomes ineligible for phase 2 selection.", + "format": "int64", + "type": "integer" + }, "min_rollout_idle_hours": { "description": "Minimum idle time between last thread activity and memory creation (hours). > 12h recommended.", "format": "int64", @@ -1141,6 +1146,10 @@ }, "type": "object" }, + "oauth_resource": { + "default": null, + "type": "string" + }, "required": { "default": null, "type": "boolean" @@ -1174,6 +1183,18 @@ }, "type": "object" }, + "RealtimeAudioToml": { + "additionalProperties": false, + "properties": { + "microphone": { + "type": "string" + }, + "speaker": { + "type": "string" + } + }, + "type": "object" + }, "ReasoningEffort": { "description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning", "enum": [ @@ -1535,6 +1556,15 @@ "default": null, "description": "Settings for app-specific controls." }, + "audio": { + "allOf": [ + { + "$ref": "#/definitions/RealtimeAudioToml" + } + ], + "default": null, + "description": "Machine-local realtime audio device preferences used by realtime voice." + }, "background_terminal_max_timeout": { "description": "Maximum poll window for background terminal output (`write_stdin`), in milliseconds. 
Default: `300000` (5 minutes).", "format": "uint64", diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index f805a9e7425..8ea23622835 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -4,11 +4,17 @@ use crate::agent::status::is_final; use crate::error::CodexErr; use crate::error::Result as CodexResult; use crate::find_thread_path_by_id_str; +use crate::rollout::RolloutRecorder; +use crate::session_prefix::format_subagent_context_line; use crate::session_prefix::format_subagent_notification_message; use crate::state_db; use crate::thread_manager::ThreadManagerState; use codex_protocol::ThreadId; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::InitialHistory; use codex_protocol::protocol::Op; +use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::TokenUsage; @@ -18,6 +24,12 @@ use std::sync::Weak; use tokio::sync::watch; const AGENT_NAMES: &str = include_str!("agent_names.txt"); +const FORKED_SPAWN_AGENT_OUTPUT_MESSAGE: &str = "You are the newly spawned agent. The prior conversation history was forked from your parent agent. 
Treat the next user message as your new task, and use the forked history only as background context."; + +#[derive(Clone, Debug, Default)] +pub(crate) struct SpawnAgentOptions { + pub(crate) fork_parent_spawn_call_id: Option, +} fn agent_nickname_list() -> Vec<&'static str> { AGENT_NAMES @@ -57,6 +69,17 @@ impl AgentControl { config: crate::config::Config, items: Vec, session_source: Option, + ) -> CodexResult { + self.spawn_agent_with_options(config, items, session_source, SpawnAgentOptions::default()) + .await + } + + pub(crate) async fn spawn_agent_with_options( + &self, + config: crate::config::Config, + items: Vec, + session_source: Option, + options: SpawnAgentOptions, ) -> CodexResult { let state = self.upgrade()?; let mut reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?; @@ -82,9 +105,75 @@ impl AgentControl { // The same `AgentControl` is sent to spawn the thread. let new_thread = match session_source { Some(session_source) => { - state - .spawn_new_thread_with_source(config, self.clone(), session_source, false, None) - .await? + if let Some(call_id) = options.fork_parent_spawn_call_id.as_ref() { + let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + .. + }) = session_source.clone() + else { + return Err(CodexErr::Fatal( + "spawn_agent fork requires a thread-spawn session source".to_string(), + )); + }; + let parent_thread = state.get_thread(parent_thread_id).await.ok(); + if let Some(parent_thread) = parent_thread.as_ref() { + // `record_conversation_items` only queues rollout writes asynchronously. + // Flush/materialize the live parent before snapshotting JSONL for a fork. 
+ parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent_thread.codex.session.flush_rollout().await; + } + let rollout_path = parent_thread + .as_ref() + .and_then(|parent_thread| parent_thread.rollout_path()) + .or(find_thread_path_by_id_str( + config.codex_home.as_path(), + &parent_thread_id.to_string(), + ) + .await?) + .ok_or_else(|| { + CodexErr::Fatal(format!( + "parent thread rollout unavailable for fork: {parent_thread_id}" + )) + })?; + let mut forked_rollout_items = + RolloutRecorder::get_rollout_history(&rollout_path) + .await? + .get_rollout_items(); + let mut output = FunctionCallOutputPayload::from_text( + FORKED_SPAWN_AGENT_OUTPUT_MESSAGE.to_string(), + ); + output.success = Some(true); + forked_rollout_items.push(RolloutItem::ResponseItem( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output, + }, + )); + let initial_history = InitialHistory::Forked(forked_rollout_items); + state + .fork_thread_with_source( + config, + initial_history, + self.clone(), + session_source, + false, + ) + .await? + } else { + state + .spawn_new_thread_with_source( + config, + self.clone(), + session_source, + false, + None, + ) + .await? + } } None => state.spawn_new_thread(config, self.clone()).await?, }; @@ -255,6 +344,40 @@ impl AgentControl { thread.total_token_usage().await } + pub(crate) async fn format_environment_context_subagents( + &self, + parent_thread_id: ThreadId, + ) -> String { + let Ok(state) = self.upgrade() else { + return String::new(); + }; + + let mut agents = Vec::new(); + for thread_id in state.list_thread_ids().await { + let Ok(thread) = state.get_thread(thread_id).await else { + continue; + }; + let snapshot = thread.config_snapshot().await; + let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: agent_parent_thread_id, + agent_nickname, + .. 
+ }) = snapshot.session_source + else { + continue; + }; + if agent_parent_thread_id != parent_thread_id { + continue; + } + agents.push(format_subagent_context_line( + &thread_id.to_string(), + agent_nickname.as_deref(), + )); + } + agents.sort(); + agents.join("\n") + } + /// Starts a detached watcher for sub-agents spawned from another thread. /// /// This is only enabled for `SubAgentSource::ThreadSpawn`, where a parent thread exists and @@ -421,6 +544,21 @@ mod tests { }) } + /// Returns true when any message item contains `needle` in a text span. + fn history_contains_text(history_items: &[ResponseItem], needle: &str) -> bool { + history_items.iter().any(|item| { + let ResponseItem::Message { content, .. } = item else { + return false; + }; + content.iter().any(|content_item| match content_item { + ContentItem::InputText { text } | ContentItem::OutputText { text } => { + text.contains(needle) + } + ContentItem::InputImage { .. } => false, + }) + }) + } + async fn wait_for_subagent_notification(parent_thread: &Arc) -> bool { let wait = async { loop { @@ -673,6 +811,242 @@ mod tests { assert_eq!(captured, Some(expected)); } + #[tokio::test] + async fn spawn_agent_can_fork_parent_thread_history() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + parent_thread + .inject_user_message_without_turn("parent seed context".to_string()) + .await; + let turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-history".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + parent_thread + .codex + .session + .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) + .await; + parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + 
parent_thread.codex.session.flush_rollout().await; + + let child_thread_id = harness + .control + .spawn_agent_with_options( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id), + }, + ) + .await + .expect("forked spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + assert_ne!(child_thread_id, parent_thread_id); + let history = child_thread.codex.session.clone_history().await; + assert!(history_contains_text( + history.raw_items(), + "parent seed context" + )); + + let expected = ( + child_thread_id, + Op::UserInput { + items: vec![UserInput::Text { + text: "child task".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }, + ); + let captured = harness + .manager + .captured_ops() + .into_iter() + .find(|entry| *entry == expected); + assert_eq!(captured, Some(expected)); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); + } + + #[tokio::test] + async fn spawn_agent_fork_injects_output_for_parent_spawn_call() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + let turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-1".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + parent_thread + .codex + .session + .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) 
+ .await; + parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent_thread.codex.session.flush_rollout().await; + + let child_thread_id = harness + .control + .spawn_agent_with_options( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), + }, + ) + .await + .expect("forked spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let history = child_thread.codex.session.clone_history().await; + let injected_output = history.raw_items().iter().find_map(|item| match item { + ResponseItem::FunctionCallOutput { call_id, output } + if call_id == &parent_spawn_call_id => + { + Some(output) + } + _ => None, + }); + let injected_output = + injected_output.expect("forked child should contain synthetic tool output"); + assert_eq!( + injected_output.text_content(), + Some(FORKED_SPAWN_AGENT_OUTPUT_MESSAGE) + ); + assert_eq!(injected_output.success, Some(true)); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); + } + + #[tokio::test] + async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + let turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-unflushed".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + 
parent_thread + .codex + .session + .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) + .await; + + let child_thread_id = harness + .control + .spawn_agent_with_options( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), + }, + ) + .await + .expect("forked spawn should flush parent rollout before loading history"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let history = child_thread.codex.session.clone_history().await; + + let mut parent_call_index = None; + let mut injected_output_index = None; + for (idx, item) in history.raw_items().iter().enumerate() { + match item { + ResponseItem::FunctionCall { call_id, .. } if call_id == &parent_spawn_call_id => { + parent_call_index = Some(idx); + } + ResponseItem::FunctionCallOutput { call_id, .. 
} + if call_id == &parent_spawn_call_id => + { + injected_output_index = Some(idx); + } + _ => {} + } + } + + let parent_call_index = + parent_call_index.expect("forked child should include the parent spawn_agent call"); + let injected_output_index = injected_output_index + .expect("forked child should include synthetic output for the parent spawn_agent call"); + assert!(parent_call_index < injected_output_index); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); + } + #[tokio::test] async fn spawn_agent_respects_max_threads_limit() { let max_threads = 1usize; diff --git a/codex-rs/core/src/agent/role.rs b/codex-rs/core/src/agent/role.rs index 81fd0b4d392..766497cda04 100644 --- a/codex-rs/core/src/agent/role.rs +++ b/codex-rs/core/src/agent/role.rs @@ -190,15 +190,17 @@ Rules: ( "awaiter".to_string(), AgentRoleConfig { - description: Some(r#"Use an `awaiter` agent EVERY TIME you must run a command that might take some very long time. + description: Some(r#"Use an `awaiter` agent EVERY TIME you must run a command that will take some very long time. This includes, but not only: * testing * monitoring of a long running process * explicit ask to wait for something -When YOU wait for the `awaiter` agent to be done, use the largest possible timeout. -Be patient with the `awaiter`. -Close the awaiter when you're done with it."#.to_string()), +Rules: +- When an awaiter is running, you can work on something else. If you need to wait for its completion, use the largest possible timeout. +- Be patient with the `awaiter`. +- Do not use an awaiter for every compilation/test if it won't take time. Only use if for long running commands. 
+- Close the awaiter when you're done with it."#.to_string()), config_file: Some("awaiter.toml".to_string().parse().unwrap_or_default()), } ) diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 33fc438727e..3fea6eed56a 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -84,19 +84,13 @@ fn reserialize_shell_outputs(items: &mut [ResponseItem]) { shell_call_ids.insert(call_id.clone()); } } - ResponseItem::CustomToolCallOutput { call_id, output } => { - if shell_call_ids.remove(call_id) - && let Some(structured) = parse_structured_shell_output(output) - { - *output = structured - } - } ResponseItem::FunctionCall { name, call_id, .. } if is_shell_tool_name(name) || name == "apply_patch" => { shell_call_ids.insert(call_id.clone()); } - ResponseItem::FunctionCallOutput { call_id, output } => { + ResponseItem::FunctionCallOutput { call_id, output } + | ResponseItem::CustomToolCallOutput { call_id, output } => { if shell_call_ids.remove(call_id) && let Some(structured) = output .text_content() @@ -240,6 +234,7 @@ mod tests { use codex_api::common::OpenAiVerbosity; use codex_api::common::TextControls; use codex_api::create_text_param_for_request; + use codex_protocol::models::FunctionCallOutputPayload; use pretty_assertions::assert_eq; use super::*; @@ -343,4 +338,62 @@ mod tests { let v = serde_json::to_value(&req).expect("json"); assert!(v.get("text").is_none()); } + + #[test] + fn reserializes_shell_outputs_for_function_and_custom_tool_calls() { + let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#; + let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello"; + let mut items = vec![ + ResponseItem::FunctionCall { + id: None, + name: "shell".to_string(), + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: 
FunctionCallOutputPayload::from_text(raw_output.to_string()), + }, + ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-2".to_string(), + name: "apply_patch".to_string(), + input: "*** Begin Patch".to_string(), + }, + ResponseItem::CustomToolCallOutput { + call_id: "call-2".to_string(), + output: FunctionCallOutputPayload::from_text(raw_output.to_string()), + }, + ]; + + reserialize_shell_outputs(&mut items); + + assert_eq!( + items, + vec![ + ResponseItem::FunctionCall { + id: None, + name: "shell".to_string(), + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_text(expected_output.to_string()), + }, + ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-2".to_string(), + name: "apply_patch".to_string(), + input: "*** Begin Patch".to_string(), + }, + ResponseItem::CustomToolCallOutput { + call_id: "call-2".to_string(), + output: FunctionCallOutputPayload::from_text(expected_output.to_string()), + }, + ] + ); + } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 68dfdbc1eb2..c511739f3a1 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -53,6 +53,8 @@ use crate::util::error_or_panic; use crate::ws_version_from_features; use async_channel::Receiver; use async_channel::Sender; +use chrono::Local; +use chrono::Utc; use codex_hooks::HookEvent; use codex_hooks::HookEventAfterAgent; use codex_hooks::HookPayload; @@ -441,6 +443,7 @@ impl Codex { thread_name: None, original_config_do_not_use: Arc::clone(&config), metrics_service_name, + app_server_client_name: None, session_source, dynamic_tools, persist_extended_history, @@ -528,6 +531,18 @@ impl Codex { self.session.steer_input(input, expected_turn_id).await } + pub(crate) async fn set_app_server_client_name( + &self, + app_server_client_name: Option, + ) -> ConstraintResult<()> { + self.session 
+ .update_settings(SessionSettingsUpdate { + app_server_client_name, + ..Default::default() + }) + .await + } + pub(crate) async fn agent_status(&self) -> AgentStatus { self.agent_status.borrow().clone() } @@ -595,6 +610,9 @@ pub(crate) struct TurnContext { /// the model as well as sandbox policies are resolved against this path /// instead of `std::env::current_dir()`. pub(crate) cwd: PathBuf, + pub(crate) current_date: Option, + pub(crate) timezone: Option, + pub(crate) app_server_client_name: Option, pub(crate) developer_instructions: Option, pub(crate) compact_prompt: Option, pub(crate) user_instructions: Option, @@ -679,6 +697,9 @@ impl TurnContext { reasoning_summary: self.reasoning_summary, session_source: self.session_source.clone(), cwd: self.cwd.clone(), + current_date: self.current_date.clone(), + timezone: self.timezone.clone(), + app_server_client_name: self.app_server_client_name.clone(), developer_instructions: self.developer_instructions.clone(), compact_prompt: self.compact_prompt.clone(), user_instructions: self.user_instructions.clone(), @@ -719,6 +740,8 @@ impl TurnContext { TurnContextItem { turn_id: Some(self.sub_id.clone()), cwd: self.cwd.clone(), + current_date: self.current_date.clone(), + timezone: self.timezone.clone(), approval_policy: self.approval_policy.value(), sandbox_policy: self.sandbox_policy.get().clone(), network: self.turn_context_network_item(), @@ -748,13 +771,23 @@ impl TurnContext { } } +fn local_time_context() -> (String, String) { + match iana_time_zone::get_timezone() { + Ok(timezone) => (Local::now().format("%Y-%m-%d").to_string(), timezone), + Err(_) => ( + Utc::now().format("%Y-%m-%d").to_string(), + "Etc/UTC".to_string(), + ), + } +} + #[derive(Clone)] pub(crate) struct SessionConfiguration { /// Provider identifier ("openai", "openrouter", ...). 
provider: ModelProviderInfo, collaboration_mode: CollaborationMode, - model_reasoning_summary: ReasoningSummaryConfig, + model_reasoning_summary: Option, /// Developer instructions that supplement the base instructions. developer_instructions: Option, @@ -794,6 +827,7 @@ pub(crate) struct SessionConfiguration { original_config_do_not_use: Arc, /// Optional service name tag for session metrics. metrics_service_name: Option, + app_server_client_name: Option, /// Source of the session (cli, vscode, exec, mcp, ...) session_source: SessionSource, dynamic_tools: Vec, @@ -824,7 +858,7 @@ impl SessionConfiguration { next_configuration.collaboration_mode = collaboration_mode; } if let Some(summary) = updates.reasoning_summary { - next_configuration.model_reasoning_summary = summary; + next_configuration.model_reasoning_summary = Some(summary); } if let Some(personality) = updates.personality { next_configuration.personality = Some(personality); @@ -841,6 +875,9 @@ impl SessionConfiguration { if let Some(cwd) = updates.cwd.clone() { next_configuration.cwd = cwd; } + if let Some(app_server_client_name) = updates.app_server_client_name.clone() { + next_configuration.app_server_client_name = Some(app_server_client_name); + } Ok(next_configuration) } } @@ -855,6 +892,7 @@ pub(crate) struct SessionSettingsUpdate { pub(crate) reasoning_summary: Option, pub(crate) final_output_json_schema: Option>, pub(crate) personality: Option, + pub(crate) app_server_client_name: Option, } impl Session { @@ -985,7 +1023,9 @@ impl Session { skills_outcome: Arc, ) -> TurnContext { let reasoning_effort = session_configuration.collaboration_mode.reasoning_effort(); - let reasoning_summary = session_configuration.model_reasoning_summary; + let reasoning_summary = session_configuration + .model_reasoning_summary + .unwrap_or(model_info.default_reasoning_summary); let otel_manager = otel_manager.clone().with_model( session_configuration.collaboration_mode.model(), model_info.slug.as_str(), @@ -1015,6 
+1055,7 @@ impl Session { .features .enabled(Feature::UseLinuxSandboxBwrap), )); + let (current_date, timezone) = local_time_context(); TurnContext { sub_id, config: per_turn_config.clone(), @@ -1026,6 +1067,9 @@ impl Session { reasoning_summary, session_source, cwd, + current_date: Some(current_date), + timezone: Some(timezone), + app_server_client_name: session_configuration.app_server_client_name.clone(), developer_instructions: session_configuration.developer_instructions.clone(), compact_prompt: session_configuration.compact_prompt.clone(), user_instructions: session_configuration.user_instructions.clone(), @@ -1124,7 +1168,7 @@ impl Session { // // - initialize RolloutRecorder with new or resumed session info // - perform default shell discovery - // - load history metadata + // - load history metadata (skipped for subagents) let rollout_fut = async { if config.ephemeral { Ok::<_, anyhow::Error>((None, None)) @@ -1141,7 +1185,16 @@ impl Session { } }; - let history_meta_fut = crate::message_history::history_metadata(&config); + let history_meta_fut = async { + if matches!( + session_configuration.session_source, + SessionSource::SubAgent(_) + ) { + (0, 0) + } else { + crate::message_history::history_metadata(&config).await + } + }; let auth_manager_clone = Arc::clone(&auth_manager); let config_for_mcp = Arc::clone(&config); let auth_and_mcp_fut = async move { @@ -1262,7 +1315,9 @@ impl Session { otel_manager.conversation_starts( config.model_provider.name.as_str(), session_configuration.collaboration_mode.reasoning_effort(), - config.model_reasoning_summary, + config + .model_reasoning_summary + .unwrap_or(ReasoningSummaryConfig::Auto), config.model_context_window, config.model_auto_compact_token_limit, config.permissions.approval_policy.value(), @@ -3077,8 +3132,15 @@ impl Session { .serialize_to_text(), ); } + let subagents = self + .services + .agent_control + .format_environment_context_subagents(self.conversation_id) + .await; 
contextual_user_sections.push( - EnvironmentContext::from_turn_context(turn_context, shell.as_ref()).serialize_to_xml(), + EnvironmentContext::from_turn_context(turn_context, shell.as_ref()) + .with_subagents(subagents) + .serialize_to_xml(), ); let mut items = Vec::with_capacity(2); @@ -3904,9 +3966,10 @@ mod handlers { sandbox_policy: Some(sandbox_policy), windows_sandbox_level: None, collaboration_mode, - reasoning_summary: Some(summary), + reasoning_summary: summary, final_output_json_schema: Some(final_output_json_schema), personality, + app_server_client_name: None, }, ) } @@ -4626,7 +4689,9 @@ async fn spawn_review_thread( let provider_for_context = provider.clone(); let otel_manager_for_context = otel_manager.clone(); let reasoning_effort = per_turn_config.model_reasoning_effort; - let reasoning_summary = per_turn_config.model_reasoning_summary; + let reasoning_summary = per_turn_config + .model_reasoning_summary + .unwrap_or(model_info.default_reasoning_summary); let session_source = parent_turn_context.session_source.clone(); let per_turn_config = Arc::new(per_turn_config); @@ -4654,6 +4719,9 @@ async fn spawn_review_thread( tools_config, features: parent_turn_context.features.clone(), ghost_snapshot: parent_turn_context.ghost_snapshot.clone(), + current_date: parent_turn_context.current_date.clone(), + timezone: parent_turn_context.timezone.clone(), + app_server_client_name: parent_turn_context.app_server_client_name.clone(), developer_instructions: None, user_instructions: None, compact_prompt: parent_turn_context.compact_prompt.clone(), @@ -5039,6 +5107,7 @@ pub(crate) async fn run_turn( .dispatch(HookPayload { session_id: sess.conversation_id, cwd: turn_context.cwd.clone(), + client: turn_context.app_server_client_name.clone(), triggered_at: chrono::Utc::now(), hook_event: HookEvent::AfterAgent { event: HookEventAfterAgent { @@ -7237,6 +7306,8 @@ mod tests { let previous_context_item = TurnContextItem { turn_id: Some(turn_context.sub_id.clone()), cwd: 
turn_context.cwd.clone(), + current_date: turn_context.current_date.clone(), + timezone: turn_context.timezone.clone(), approval_policy: turn_context.approval_policy.value(), sandbox_policy: turn_context.sandbox_policy.get().clone(), network: None, @@ -7274,6 +7345,8 @@ mod tests { let mut previous_context_item = TurnContextItem { turn_id: Some(turn_context.sub_id.clone()), cwd: turn_context.cwd.clone(), + current_date: turn_context.current_date.clone(), + timezone: turn_context.timezone.clone(), approval_policy: turn_context.approval_policy.value(), sandbox_policy: turn_context.sandbox_policy.get().clone(), network: None, @@ -7468,7 +7541,10 @@ mod tests { .record_context_updates_and_set_reference_context_item(&turn_context, None) .await; let history_after_second_seed = session.clone_history().await; - assert_eq!(expected, history_after_second_seed.raw_items()); + assert_eq!( + history_after_seed.raw_items(), + history_after_second_seed.raw_items() + ); } #[tokio::test] @@ -7636,6 +7712,8 @@ mod tests { let previous_context_item = TurnContextItem { turn_id: Some(turn_context.sub_id.clone()), cwd: turn_context.cwd.clone(), + current_date: turn_context.current_date.clone(), + timezone: turn_context.timezone.clone(), approval_policy: turn_context.approval_policy.value(), sandbox_policy: turn_context.sandbox_policy.get().clone(), network: None, @@ -7839,6 +7917,7 @@ mod tests { thread_name: None, original_config_do_not_use: Arc::clone(&config), metrics_service_name: None, + app_server_client_name: None, session_source: SessionSource::Exec, dynamic_tools: Vec::new(), persist_extended_history: false, @@ -7931,6 +8010,7 @@ mod tests { thread_name: None, original_config_do_not_use: Arc::clone(&config), metrics_service_name: None, + app_server_client_name: None, session_source: SessionSource::Exec, dynamic_tools: Vec::new(), persist_extended_history: false, @@ -8242,6 +8322,7 @@ mod tests { thread_name: None, original_config_do_not_use: Arc::clone(&config), 
metrics_service_name: None, + app_server_client_name: None, session_source: SessionSource::Exec, dynamic_tools: Vec::new(), persist_extended_history: false, @@ -8295,6 +8376,7 @@ mod tests { thread_name: None, original_config_do_not_use: Arc::clone(&config), metrics_service_name: None, + app_server_client_name: None, session_source: SessionSource::Exec, dynamic_tools: Vec::new(), persist_extended_history: false, @@ -8376,6 +8458,7 @@ mod tests { thread_name: None, original_config_do_not_use: Arc::clone(&config), metrics_service_name: None, + app_server_client_name: None, session_source: SessionSource::Exec, dynamic_tools: Vec::new(), persist_extended_history: false, @@ -8535,6 +8618,7 @@ mod tests { thread_name: None, original_config_do_not_use: Arc::clone(&config), metrics_service_name: None, + app_server_client_name: None, session_source: SessionSource::Exec, dynamic_tools, persist_extended_history: false, @@ -8801,6 +8885,42 @@ mod tests { assert!(environment_update.contains("blocked.example.com")); } + #[tokio::test] + async fn build_settings_update_items_emits_environment_item_for_time_changes() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.current_date = Some("2026-02-27".to_string()); + current_context.timezone = Some("Europe/Berlin".to_string()); + + let reference_context_item = previous_context.to_turn_context_item(); + let update_items = session.build_settings_update_items( + Some(&reference_context_item), + None, + ¤t_context, + ); + + let environment_update = update_items + .iter() + .find_map(|item| match item { + ResponseItem::Message { role, content, .. 
} if role == "user" => { + let [ContentItem::InputText { text }] = content.as_slice() else { + return None; + }; + text.contains("").then_some(text) + } + _ => None, + }) + .expect("environment update item should be emitted"); + assert!(environment_update.contains("2026-02-27")); + assert!(environment_update.contains("Europe/Berlin")); + } + #[tokio::test] async fn record_context_updates_and_set_reference_context_item_injects_full_context_when_baseline_missing() { diff --git a/codex-rs/core/src/codex_thread.rs b/codex-rs/core/src/codex_thread.rs index b493075d4a3..19a8214ee08 100644 --- a/codex-rs/core/src/codex_thread.rs +++ b/codex-rs/core/src/codex_thread.rs @@ -1,6 +1,7 @@ use crate::agent::AgentStatus; use crate::codex::Codex; use crate::codex::SteerInputError; +use crate::config::ConstraintResult; use crate::error::Result as CodexResult; use crate::features::Feature; use crate::file_watcher::WatchRegistration; @@ -67,6 +68,15 @@ impl CodexThread { self.codex.steer_input(input, expected_turn_id).await } + pub async fn set_app_server_client_name( + &self, + app_server_client_name: Option, + ) -> ConstraintResult<()> { + self.codex + .set_app_server_client_name(app_server_client_name) + .await + } + /// Use sparingly: this is intended to be removed soon. 
pub async fn submit_with_id(&self, sub: Submission) -> CodexResult<()> { self.codex.submit_with_id(sub).await diff --git a/codex-rs/core/src/compact_remote.rs b/codex-rs/core/src/compact_remote.rs index c019a58ce52..cc5f5164c39 100644 --- a/codex-rs/core/src/compact_remote.rs +++ b/codex-rs/core/src/compact_remote.rs @@ -105,7 +105,6 @@ async fn run_remote_compact_task_inner_impl( "trimmed history items before remote compaction" ); } - // Required to keep `/undo` available after compaction let ghost_snapshots: Vec = history .raw_items() diff --git a/codex-rs/core/src/config/edit.rs b/codex-rs/core/src/config/edit.rs index fceb9659989..592f50d9075 100644 --- a/codex-rs/core/src/config/edit.rs +++ b/codex-rs/core/src/config/edit.rs @@ -195,6 +195,11 @@ mod document_helpers { { entry["scopes"] = array_from_iter(scopes.iter().cloned()); } + if let Some(resource) = &config.oauth_resource + && !resource.is_empty() + { + entry["oauth_resource"] = value(resource.clone()); + } entry } @@ -839,6 +844,30 @@ impl ConfigEditsBuilder { self } + pub fn set_realtime_microphone(mut self, microphone: Option<&str>) -> Self { + let segments = vec!["audio".to_string(), "microphone".to_string()]; + match microphone { + Some(microphone) => self.edits.push(ConfigEdit::SetPath { + segments, + value: value(microphone), + }), + None => self.edits.push(ConfigEdit::ClearPath { segments }), + } + self + } + + pub fn set_realtime_speaker(mut self, speaker: Option<&str>) -> Self { + let segments = vec!["audio".to_string(), "speaker".to_string()]; + match speaker { + Some(speaker) => self.edits.push(ConfigEdit::SetPath { + segments, + value: value(speaker), + }), + None => self.edits.push(ConfigEdit::ClearPath { segments }), + } + self + } + pub fn clear_legacy_windows_sandbox_keys(mut self) -> Self { for key in [ "experimental_windows_sandbox", @@ -1441,6 +1470,7 @@ gpt-5 = "gpt-5.1" enabled_tools: Some(vec!["one".to_string(), "two".to_string()]), disabled_tools: None, scopes: None, + 
oauth_resource: None, }, ); @@ -1465,6 +1495,7 @@ gpt-5 = "gpt-5.1" enabled_tools: None, disabled_tools: Some(vec!["forbidden".to_string()]), scopes: None, + oauth_resource: Some("https://resource.example.com".to_string()), }, ); @@ -1483,6 +1514,7 @@ bearer_token_env_var = \"TOKEN\" enabled = false startup_timeout_sec = 5.0 disabled_tools = [\"forbidden\"] +oauth_resource = \"https://resource.example.com\" [mcp_servers.http.http_headers] Z-Header = \"z\" @@ -1532,6 +1564,7 @@ foo = { command = "cmd" } enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); @@ -1578,6 +1611,7 @@ foo = { command = "cmd" } # keep me enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); @@ -1623,6 +1657,7 @@ foo = { command = "cmd", args = ["--flag"] } # keep me enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); @@ -1669,6 +1704,7 @@ foo = { command = "cmd" } enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); @@ -1804,6 +1840,50 @@ model_reasoning_effort = "high" assert_eq!(notice, Some(true)); } + #[test] + fn blocking_builder_set_realtime_audio_persists_and_clears() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + ConfigEditsBuilder::new(codex_home) + .set_realtime_microphone(Some("USB Mic")) + .set_realtime_speaker(Some("Desk Speakers")) + .apply_blocking() + .expect("persist realtime audio"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let config: TomlValue = toml::from_str(&raw).expect("parse config"); + let realtime_audio = config + .get("audio") + .and_then(TomlValue::as_table) + .expect("audio table should exist"); + assert_eq!( + realtime_audio.get("microphone").and_then(TomlValue::as_str), + Some("USB Mic") + ); + assert_eq!( + realtime_audio.get("speaker").and_then(TomlValue::as_str), + Some("Desk Speakers") + ); + + ConfigEditsBuilder::new(codex_home) + 
.set_realtime_microphone(None) + .apply_blocking() + .expect("clear realtime microphone"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let config: TomlValue = toml::from_str(&raw).expect("parse config"); + let realtime_audio = config + .get("audio") + .and_then(TomlValue::as_table) + .expect("audio table should exist"); + assert_eq!(realtime_audio.get("microphone"), None); + assert_eq!( + realtime_audio.get("speaker").and_then(TomlValue::as_str), + Some("Desk Speakers") + ); + } + #[test] fn replace_mcp_servers_blocking_clears_table_when_empty() { let tmp = tempdir().expect("tmpdir"); diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 306ba49c100..b3661fe790b 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -49,8 +49,6 @@ use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME; use crate::protocol::AskForApproval; use crate::protocol::ReadOnlyAccess; use crate::protocol::SandboxPolicy; -#[cfg(target_os = "macos")] -use crate::seatbelt_permissions::MacOsSeatbeltProfileExtensions; use crate::unified_exec::DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS; use crate::unified_exec::MIN_EMPTY_YIELD_TIME_MS; use crate::windows_sandbox::WindowsSandboxLevelExt; @@ -66,6 +64,7 @@ use codex_protocol::config_types::TrustLevel; use codex_protocol::config_types::Verbosity; use codex_protocol::config_types::WebSearchMode; use codex_protocol::config_types::WindowsSandboxLevel; +use codex_protocol::models::MacOsSeatbeltProfileExtensions; use codex_protocol::openai_models::ModelsResponse; use codex_protocol::openai_models::ReasoningEffort; use codex_rmcp_client::OAuthCredentialsStoreMode; @@ -82,8 +81,6 @@ use std::path::Path; use std::path::PathBuf; #[cfg(test)] use tempfile::tempdir; -#[cfg(not(target_os = "macos"))] -type MacOsSeatbeltProfileExtensions = (); use crate::config::permissions::network_proxy_config_from_permissions; use 
crate::config::profile::ConfigProfile; @@ -412,9 +409,9 @@ pub struct Config { /// global default"). pub plan_mode_reasoning_effort: Option, - /// If not "none", the value to use for `reasoning.summary` when making a - /// request using the Responses API. - pub model_reasoning_summary: ReasoningSummary, + /// Optional value to use for `reasoning.summary` when making a request + /// using the Responses API. When unset, the model catalog default is used. + pub model_reasoning_summary: Option, /// Optional override to force-enable reasoning summaries for the configured model. pub model_supports_reasoning_summaries: Option, @@ -429,6 +426,9 @@ pub struct Config { /// Base URL for requests to ChatGPT (as opposed to the OpenAI API). pub chatgpt_base_url: String, + /// Machine-local realtime audio device preferences used by realtime voice. + pub realtime_audio: RealtimeAudioConfig, + /// Experimental / do not use. Overrides only the realtime conversation /// websocket transport base URL (the `Op::RealtimeConversation` `/ws` /// connection) without changing normal provider HTTP requests. @@ -1178,6 +1178,10 @@ pub struct ConfigToml { /// Base URL for requests to ChatGPT (as opposed to the OpenAI API). pub chatgpt_base_url: Option, + /// Machine-local realtime audio device preferences used by realtime voice. + #[serde(default)] + pub audio: Option, + /// Experimental / do not use. Overrides only the realtime conversation /// websocket transport base URL (the `Op::RealtimeConversation` `/ws` /// connection) without changing normal provider HTTP requests. 
@@ -1309,6 +1313,19 @@ impl ProjectConfig { } } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct RealtimeAudioConfig { + pub microphone: Option, + pub speaker: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema)] +#[schemars(deny_unknown_fields)] +pub struct RealtimeAudioToml { + pub microphone: Option, + pub speaker: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, JsonSchema)] #[schemars(deny_unknown_fields)] pub struct ToolsToml { @@ -2141,8 +2158,7 @@ impl Config { .or(cfg.plan_mode_reasoning_effort), model_reasoning_summary: config_profile .model_reasoning_summary - .or(cfg.model_reasoning_summary) - .unwrap_or_default(), + .or(cfg.model_reasoning_summary), model_supports_reasoning_summaries: cfg.model_supports_reasoning_summaries, model_catalog, model_verbosity: config_profile.model_verbosity.or(cfg.model_verbosity), @@ -2150,6 +2166,12 @@ impl Config { .chatgpt_base_url .or(cfg.chatgpt_base_url) .unwrap_or("https://chatgpt.com/backend-api/".to_string()), + realtime_audio: cfg + .audio + .map_or_else(RealtimeAudioConfig::default, |audio| RealtimeAudioConfig { + microphone: audio.microphone, + speaker: audio.speaker, + }), experimental_realtime_ws_base_url: cfg.experimental_realtime_ws_base_url, experimental_realtime_ws_backend_prompt: cfg.experimental_realtime_ws_backend_prompt, forced_chatgpt_workspace_id, @@ -2411,6 +2433,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, } } @@ -2430,6 +2453,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, } } @@ -2467,6 +2491,7 @@ persistence = "none" let memories = r#" [memories] max_raw_memories_for_global = 512 +max_unused_days = 21 max_rollout_age_days = 42 max_rollouts_per_startup = 9 min_rollout_idle_hours = 24 @@ -2478,6 +2503,7 @@ phase_2_model = "gpt-5" assert_eq!( Some(MemoriesToml { max_raw_memories_for_global: Some(512), + 
max_unused_days: Some(21), max_rollout_age_days: Some(42), max_rollouts_per_startup: Some(9), min_rollout_idle_hours: Some(24), @@ -2497,6 +2523,7 @@ phase_2_model = "gpt-5" config.memories, MemoriesConfig { max_raw_memories_for_global: 512, + max_unused_days: 21, max_rollout_age_days: 42, max_rollouts_per_startup: 9, min_rollout_idle_hours: 24, @@ -3464,6 +3491,7 @@ profile = "project" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); @@ -3621,6 +3649,7 @@ bearer_token = "secret" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -3692,6 +3721,7 @@ ZIG_VAR = "3" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -3743,6 +3773,7 @@ ZIG_VAR = "3" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -3792,6 +3823,7 @@ ZIG_VAR = "3" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -3857,6 +3889,7 @@ startup_timeout_sec = 2.0 enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); apply_blocking( @@ -3934,6 +3967,7 @@ X-Auth = "DOCS_AUTH" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -3964,6 +3998,7 @@ X-Auth = "DOCS_AUTH" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); apply_blocking( @@ -4032,6 +4067,7 @@ url = "https://example.com/mcp" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ), ( @@ -4052,6 +4088,7 @@ url = "https://example.com/mcp" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ), ]); @@ -4135,6 +4172,7 @@ url = "https://example.com/mcp" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -4180,6 +4218,7 @@ url = "https://example.com/mcp" enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ 
-4225,6 +4264,7 @@ url = "https://example.com/mcp" enabled_tools: Some(vec!["allowed".to_string()]), disabled_tools: Some(vec!["blocked".to_string()]), scopes: None, + oauth_resource: None, }, )]); @@ -4253,6 +4293,51 @@ url = "https://example.com/mcp" Ok(()) } + #[tokio::test] + async fn replace_mcp_servers_streamable_http_serializes_oauth_resource() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: Some("https://resource.example.com".to_string()), + }, + )]); + + apply_blocking( + codex_home.path(), + None, + &[ConfigEdit::ReplaceMcpServers(servers.clone())], + )?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert!(serialized.contains(r#"oauth_resource = "https://resource.example.com""#)); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + assert_eq!( + docs.oauth_resource.as_deref(), + Some("https://resource.example.com") + ); + + Ok(()) + } + #[tokio::test] async fn set_model_updates_defaults() -> anyhow::Result<()> { let codex_home = TempDir::new()?; @@ -4761,12 +4846,13 @@ model_verbosity = "high" show_raw_agent_reasoning: false, model_reasoning_effort: Some(ReasoningEffort::High), plan_mode_reasoning_effort: None, - model_reasoning_summary: ReasoningSummary::Detailed, + model_reasoning_summary: Some(ReasoningSummary::Detailed), model_supports_reasoning_summaries: None, model_catalog: None, model_verbosity: None, personality: 
Some(Personality::Pragmatic), chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + realtime_audio: RealtimeAudioConfig::default(), experimental_realtime_ws_base_url: None, experimental_realtime_ws_backend_prompt: None, base_instructions: None, @@ -4887,12 +4973,13 @@ model_verbosity = "high" show_raw_agent_reasoning: false, model_reasoning_effort: None, plan_mode_reasoning_effort: None, - model_reasoning_summary: ReasoningSummary::default(), + model_reasoning_summary: None, model_supports_reasoning_summaries: None, model_catalog: None, model_verbosity: None, personality: Some(Personality::Pragmatic), chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + realtime_audio: RealtimeAudioConfig::default(), experimental_realtime_ws_base_url: None, experimental_realtime_ws_backend_prompt: None, base_instructions: None, @@ -5011,12 +5098,13 @@ model_verbosity = "high" show_raw_agent_reasoning: false, model_reasoning_effort: None, plan_mode_reasoning_effort: None, - model_reasoning_summary: ReasoningSummary::default(), + model_reasoning_summary: None, model_supports_reasoning_summaries: None, model_catalog: None, model_verbosity: None, personality: Some(Personality::Pragmatic), chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + realtime_audio: RealtimeAudioConfig::default(), experimental_realtime_ws_base_url: None, experimental_realtime_ws_backend_prompt: None, base_instructions: None, @@ -5121,12 +5209,13 @@ model_verbosity = "high" show_raw_agent_reasoning: false, model_reasoning_effort: Some(ReasoningEffort::High), plan_mode_reasoning_effort: None, - model_reasoning_summary: ReasoningSummary::Detailed, + model_reasoning_summary: Some(ReasoningSummary::Detailed), model_supports_reasoning_summaries: None, model_catalog: None, model_verbosity: Some(Verbosity::High), personality: Some(Personality::Pragmatic), chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + realtime_audio: RealtimeAudioConfig::default(), 
experimental_realtime_ws_base_url: None, experimental_realtime_ws_backend_prompt: None, base_instructions: None, @@ -5971,6 +6060,39 @@ experimental_realtime_ws_backend_prompt = "prompt from config" ); Ok(()) } + + #[test] + fn realtime_audio_loads_from_config_toml() -> std::io::Result<()> { + let cfg: ConfigToml = toml::from_str( + r#" +[audio] +microphone = "USB Mic" +speaker = "Desk Speakers" +"#, + ) + .expect("TOML deserialization should succeed"); + + let realtime_audio = cfg + .audio + .as_ref() + .expect("realtime audio config should be present"); + assert_eq!(realtime_audio.microphone.as_deref(), Some("USB Mic")); + assert_eq!(realtime_audio.speaker.as_deref(), Some("Desk Speakers")); + + let codex_home = TempDir::new()?; + let config = Config::load_from_base_config_with_overrides( + cfg, + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + )?; + + assert_eq!(config.realtime_audio.microphone.as_deref(), Some("USB Mic")); + assert_eq!( + config.realtime_audio.speaker.as_deref(), + Some("Desk Speakers") + ); + Ok(()) + } } #[cfg(test)] diff --git a/codex-rs/core/src/config/types.rs b/codex-rs/core/src/config/types.rs index 678766e8078..ec0bf15320b 100644 --- a/codex-rs/core/src/config/types.rs +++ b/codex-rs/core/src/config/types.rs @@ -27,6 +27,7 @@ pub const DEFAULT_MEMORIES_MAX_ROLLOUTS_PER_STARTUP: usize = 16; pub const DEFAULT_MEMORIES_MAX_ROLLOUT_AGE_DAYS: i64 = 30; pub const DEFAULT_MEMORIES_MIN_ROLLOUT_IDLE_HOURS: i64 = 6; pub const DEFAULT_MEMORIES_MAX_RAW_MEMORIES_FOR_GLOBAL: usize = 1_024; +pub const DEFAULT_MEMORIES_MAX_UNUSED_DAYS: i64 = 30; #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, JsonSchema)] #[serde(rename_all = "kebab-case")] @@ -98,6 +99,10 @@ pub struct McpServerConfig { /// Optional OAuth scopes to request during MCP login. #[serde(default, skip_serializing_if = "Option::is_none")] pub scopes: Option>, + + /// Optional OAuth resource parameter to include during MCP login (RFC 8707). 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub oauth_resource: Option, } // Raw MCP config shape used for deserialization and JSON Schema generation. @@ -142,6 +147,8 @@ pub(crate) struct RawMcpServerConfig { pub disabled_tools: Option>, #[serde(default)] pub scopes: Option>, + #[serde(default)] + pub oauth_resource: Option, } impl<'de> Deserialize<'de> for McpServerConfig { @@ -165,6 +172,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { let enabled_tools = raw.enabled_tools.clone(); let disabled_tools = raw.disabled_tools.clone(); let scopes = raw.scopes.clone(); + let oauth_resource = raw.oauth_resource.clone(); fn throw_if_set(transport: &str, field: &str, value: Option<&T>) -> Result<(), E> where @@ -188,6 +196,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { throw_if_set("stdio", "bearer_token", raw.bearer_token.as_ref())?; throw_if_set("stdio", "http_headers", raw.http_headers.as_ref())?; throw_if_set("stdio", "env_http_headers", raw.env_http_headers.as_ref())?; + throw_if_set("stdio", "oauth_resource", raw.oauth_resource.as_ref())?; McpServerTransportConfig::Stdio { command, args: raw.args.clone().unwrap_or_default(), @@ -221,6 +230,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { enabled_tools, disabled_tools, scopes, + oauth_resource, }) } } @@ -363,6 +373,8 @@ pub struct FeedbackConfigToml { pub struct MemoriesToml { /// Maximum number of recent raw memories retained for global consolidation. pub max_raw_memories_for_global: Option, + /// Maximum number of days since a memory was last used before it becomes ineligible for phase 2 selection. + pub max_unused_days: Option, /// Maximum age of the threads used for memories. pub max_rollout_age_days: Option, /// Maximum number of rollout candidates processed per pass. 
@@ -379,6 +391,7 @@ pub struct MemoriesToml { #[derive(Debug, Clone, PartialEq, Eq)] pub struct MemoriesConfig { pub max_raw_memories_for_global: usize, + pub max_unused_days: i64, pub max_rollout_age_days: i64, pub max_rollouts_per_startup: usize, pub min_rollout_idle_hours: i64, @@ -390,6 +403,7 @@ impl Default for MemoriesConfig { fn default() -> Self { Self { max_raw_memories_for_global: DEFAULT_MEMORIES_MAX_RAW_MEMORIES_FOR_GLOBAL, + max_unused_days: DEFAULT_MEMORIES_MAX_UNUSED_DAYS, max_rollout_age_days: DEFAULT_MEMORIES_MAX_ROLLOUT_AGE_DAYS, max_rollouts_per_startup: DEFAULT_MEMORIES_MAX_ROLLOUTS_PER_STARTUP, min_rollout_idle_hours: DEFAULT_MEMORIES_MIN_ROLLOUT_IDLE_HOURS, @@ -407,6 +421,10 @@ impl From for MemoriesConfig { .max_raw_memories_for_global .unwrap_or(defaults.max_raw_memories_for_global) .min(4096), + max_unused_days: toml + .max_unused_days + .unwrap_or(defaults.max_unused_days) + .clamp(0, 365), max_rollout_age_days: toml .max_rollout_age_days .unwrap_or(defaults.max_rollout_age_days) @@ -1084,6 +1102,22 @@ mod tests { ); } + #[test] + fn deserialize_streamable_http_server_config_with_oauth_resource() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + oauth_resource = "https://api.example.com" + "#, + ) + .expect("should deserialize http config with oauth_resource"); + + assert_eq!( + cfg.oauth_resource, + Some("https://api.example.com".to_string()) + ); + } + #[test] fn deserialize_server_config_with_tool_filters() { let cfg: McpServerConfig = toml::from_str( @@ -1138,6 +1172,20 @@ mod tests { "#, ) .expect_err("should reject env_http_headers for stdio transport"); + + let err = toml::from_str::( + r#" + command = "echo" + oauth_resource = "https://api.example.com" + "#, + ) + .expect_err("should reject oauth_resource for stdio transport"); + + assert!( + err.to_string() + .contains("oauth_resource is not supported for stdio"), + "unexpected error: {err}" + ); } #[test] diff --git 
a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs index 016642b3314..e4b7755abf4 100644 --- a/codex-rs/core/src/context_manager/history.rs +++ b/codex-rs/core/src/context_manager/history.rs @@ -344,32 +344,21 @@ impl ContextManager { let policy_with_serialization_budget = policy * 1.2; match item { ResponseItem::FunctionCallOutput { call_id, output } => { - let body = match &output.body { - FunctionCallOutputBody::Text(content) => FunctionCallOutputBody::Text( - truncate_text(content, policy_with_serialization_budget), - ), - FunctionCallOutputBody::ContentItems(items) => { - FunctionCallOutputBody::ContentItems( - truncate_function_output_items_with_policy( - items, - policy_with_serialization_budget, - ), - ) - } - }; ResponseItem::FunctionCallOutput { call_id: call_id.clone(), - output: FunctionCallOutputPayload { - body, - success: output.success, - }, + output: truncate_function_output_payload( + output, + policy_with_serialization_budget, + ), } } ResponseItem::CustomToolCallOutput { call_id, output } => { - let truncated = truncate_text(output, policy_with_serialization_budget); ResponseItem::CustomToolCallOutput { call_id: call_id.clone(), - output: truncated, + output: truncate_function_output_payload( + output, + policy_with_serialization_budget, + ), } } ResponseItem::Message { .. 
} @@ -385,6 +374,25 @@ impl ContextManager { } } +fn truncate_function_output_payload( + output: &FunctionCallOutputPayload, + policy: TruncationPolicy, +) -> FunctionCallOutputPayload { + let body = match &output.body { + FunctionCallOutputBody::Text(content) => { + FunctionCallOutputBody::Text(truncate_text(content, policy)) + } + FunctionCallOutputBody::ContentItems(items) => FunctionCallOutputBody::ContentItems( + truncate_function_output_items_with_policy(items, policy), + ), + }; + + FunctionCallOutputPayload { + body, + success: output.success, + } +} + /// API messages include every non-system item (user/assistant messages, reasoning, /// tool calls, tool outputs, shell calls, and web-search calls). fn is_api_message(message: &ResponseItem) -> bool { @@ -508,7 +516,8 @@ fn image_data_url_estimate_adjustment(item: &ResponseItem) -> (i64, i64) { } } } - ResponseItem::FunctionCallOutput { output, .. } => { + ResponseItem::FunctionCallOutput { output, .. } + | ResponseItem::CustomToolCallOutput { output, .. 
} => { if let FunctionCallOutputBody::ContentItems(items) = &output.body { for content_item in items { if let FunctionCallOutputContentItem::InputImage { image_url } = content_item { diff --git a/codex-rs/core/src/context_manager/history_tests.rs b/codex-rs/core/src/context_manager/history_tests.rs index 52fff81ed01..798abc76751 100644 --- a/codex-rs/core/src/context_manager/history_tests.rs +++ b/codex-rs/core/src/context_manager/history_tests.rs @@ -67,7 +67,7 @@ fn user_input_text_msg(text: &str) -> ResponseItem { fn custom_tool_call_output(call_id: &str, output: &str) -> ResponseItem { ResponseItem::CustomToolCallOutput { call_id: call_id.to_string(), - output: output.to_string(), + output: FunctionCallOutputPayload::from_text(output.to_string()), } } @@ -279,6 +279,24 @@ fn for_prompt_strips_images_when_model_does_not_support_images() { }, ]), }, + ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "tool-1".to_string(), + name: "js_repl".to_string(), + input: "view_image".to_string(), + }, + ResponseItem::CustomToolCallOutput { + call_id: "tool-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputText { + text: "js repl result".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "https://example.com/js-repl-result.png".to_string(), + }, + ]), + }, ]; let history = create_history_with_items(items); let text_only_modalities = vec![InputModality::Text]; @@ -321,6 +339,25 @@ fn for_prompt_strips_images_when_model_does_not_support_images() { }, ]), }, + ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "tool-1".to_string(), + name: "js_repl".to_string(), + input: "view_image".to_string(), + }, + ResponseItem::CustomToolCallOutput { + call_id: "tool-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputText { + text: "js repl result".to_string(), + }, + 
FunctionCallOutputContentItem::InputText { + text: "image content omitted because you do not support image input" + .to_string(), + }, + ]), + }, ]; assert_eq!(stripped, expected); @@ -671,7 +708,7 @@ fn remove_first_item_handles_custom_tool_pair() { }, ResponseItem::CustomToolCallOutput { call_id: "tool-1".to_string(), - output: "ok".to_string(), + output: FunctionCallOutputPayload::from_text("ok".to_string()), }, ]; let mut h = create_history_with_items(items); @@ -750,7 +787,7 @@ fn record_items_truncates_custom_tool_call_output_content() { let long_output = line.repeat(2_500); let item = ResponseItem::CustomToolCallOutput { call_id: "tool-200".to_string(), - output: long_output.clone(), + output: FunctionCallOutputPayload::from_text(long_output.clone()), }; history.record_items([&item], policy); @@ -758,7 +795,8 @@ fn record_items_truncates_custom_tool_call_output_content() { assert_eq!(history.items.len(), 1); match &history.items[0] { ResponseItem::CustomToolCallOutput { output, .. 
} => { - assert_ne!(output, &long_output); + let output = output.text_content().unwrap_or_default(); + assert_ne!(output, long_output); assert!( output.contains("tokens truncated"), "expected token-based truncation marker, got {output}" @@ -949,7 +987,7 @@ fn normalize_adds_missing_output_for_custom_tool_call() { }, ResponseItem::CustomToolCallOutput { call_id: "tool-x".to_string(), - output: "aborted".to_string(), + output: FunctionCallOutputPayload::from_text("aborted".to_string()), }, ] ); @@ -1016,7 +1054,7 @@ fn normalize_removes_orphan_function_call_output() { fn normalize_removes_orphan_custom_tool_call_output() { let items = vec![ResponseItem::CustomToolCallOutput { call_id: "orphan-2".to_string(), - output: "ok".to_string(), + output: FunctionCallOutputPayload::from_text("ok".to_string()), }]; let mut h = create_history_with_items(items); @@ -1089,7 +1127,7 @@ fn normalize_mixed_inserts_and_removals() { }, ResponseItem::CustomToolCallOutput { call_id: "t1".to_string(), - output: "aborted".to_string(), + output: FunctionCallOutputPayload::from_text("aborted".to_string()), }, ResponseItem::LocalShellCall { id: None, @@ -1191,7 +1229,7 @@ fn normalize_removes_orphan_function_call_output_panics_in_debug() { fn normalize_removes_orphan_custom_tool_call_output_panics_in_debug() { let items = vec![ResponseItem::CustomToolCallOutput { call_id: "orphan-2".to_string(), - output: "ok".to_string(), + output: FunctionCallOutputPayload::from_text("ok".to_string()), }]; let mut h = create_history_with_items(items); h.normalize_history(&default_input_modalities()); @@ -1294,6 +1332,28 @@ fn image_data_url_payload_does_not_dominate_function_call_output_estimate() { assert!(estimated < raw_len); } +#[test] +fn image_data_url_payload_does_not_dominate_custom_tool_call_output_estimate() { + let payload = "C".repeat(50_000); + let image_url = format!("data:image/png;base64,{payload}"); + let item = ResponseItem::CustomToolCallOutput { + call_id: "call-js-repl".to_string(), + 
output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputText { + text: "Screenshot captured".to_string(), + }, + FunctionCallOutputContentItem::InputImage { image_url }, + ]), + }; + + let raw_len = serde_json::to_string(&item).unwrap().len() as i64; + let estimated = estimate_response_item_model_visible_bytes(&item); + let expected = raw_len - payload.len() as i64 + IMAGE_BYTES_ESTIMATE; + + assert_eq!(estimated, expected); + assert!(estimated < raw_len); +} + #[test] fn non_base64_image_urls_are_unchanged() { let message_item = ResponseItem::Message { diff --git a/codex-rs/core/src/context_manager/normalize.rs b/codex-rs/core/src/context_manager/normalize.rs index a4fe9e64fd3..572ac51fc81 100644 --- a/codex-rs/core/src/context_manager/normalize.rs +++ b/codex-rs/core/src/context_manager/normalize.rs @@ -1,7 +1,6 @@ use std::collections::HashSet; use codex_protocol::models::ContentItem; -use codex_protocol::models::FunctionCallOutputBody; use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseItem; @@ -35,10 +34,7 @@ pub(crate) fn ensure_call_outputs_present(items: &mut Vec) { idx, ResponseItem::FunctionCallOutput { call_id: call_id.clone(), - output: FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text("aborted".to_string()), - ..Default::default() - }, + output: FunctionCallOutputPayload::from_text("aborted".to_string()), }, )); } @@ -59,7 +55,7 @@ pub(crate) fn ensure_call_outputs_present(items: &mut Vec) { idx, ResponseItem::CustomToolCallOutput { call_id: call_id.clone(), - output: "aborted".to_string(), + output: FunctionCallOutputPayload::from_text("aborted".to_string()), }, )); } @@ -82,10 +78,7 @@ pub(crate) fn ensure_call_outputs_present(items: &mut Vec) { idx, ResponseItem::FunctionCallOutput { call_id: call_id.clone(), - output: FunctionCallOutputPayload { - body: 
FunctionCallOutputBody::Text("aborted".to_string()), - ..Default::default() - }, + output: FunctionCallOutputPayload::from_text("aborted".to_string()), }, )); } @@ -245,7 +238,8 @@ pub(crate) fn strip_images_when_unsupported( } *content = normalized_content; } - ResponseItem::FunctionCallOutput { output, .. } => { + ResponseItem::FunctionCallOutput { output, .. } + | ResponseItem::CustomToolCallOutput { output, .. } => { if let Some(content_items) = output.content_items_mut() { let mut normalized_content_items = Vec::with_capacity(content_items.len()); for content_item in content_items.iter() { diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs index 8d8d3c6dec4..3e9ed871e30 100644 --- a/codex-rs/core/src/environment_context.rs +++ b/codex-rs/core/src/environment_context.rs @@ -13,7 +13,10 @@ use std::path::PathBuf; pub(crate) struct EnvironmentContext { pub cwd: Option, pub shell: Shell, + pub current_date: Option, + pub timezone: Option, pub network: Option, + pub subagents: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] @@ -23,11 +26,21 @@ pub(crate) struct NetworkContext { } impl EnvironmentContext { - pub fn new(cwd: Option, shell: Shell, network: Option) -> Self { + pub fn new( + cwd: Option, + shell: Shell, + current_date: Option, + timezone: Option, + network: Option, + subagents: Option, + ) -> Self { Self { cwd, shell, + current_date, + timezone, network, + subagents, } } @@ -37,10 +50,17 @@ impl EnvironmentContext { pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool { let EnvironmentContext { cwd, + current_date, + timezone, network, + subagents, shell: _, } = other; - self.cwd == *cwd && self.network == *network + self.cwd == *cwd + && self.current_date == *current_date + && self.timezone == *timezone + && self.network == *network + && self.subagents == *subagents } pub fn diff_from_turn_context_item( @@ -55,19 +75,24 @@ impl EnvironmentContext { } else { 
None }; + let current_date = after.current_date.clone(); + let timezone = after.timezone.clone(); let network = if before_network != after_network { after_network } else { before_network }; - EnvironmentContext::new(cwd, shell.clone(), network) + EnvironmentContext::new(cwd, shell.clone(), current_date, timezone, network, None) } pub fn from_turn_context(turn_context: &TurnContext, shell: &Shell) -> Self { Self::new( Some(turn_context.cwd.clone()), shell.clone(), + turn_context.current_date.clone(), + turn_context.timezone.clone(), Self::network_from_turn_context(turn_context), + None, ) } @@ -75,10 +100,20 @@ impl EnvironmentContext { Self::new( Some(turn_context_item.cwd.clone()), shell.clone(), + turn_context_item.current_date.clone(), + turn_context_item.timezone.clone(), Self::network_from_turn_context_item(turn_context_item), + None, ) } + pub fn with_subagents(mut self, subagents: String) -> Self { + if !subagents.is_empty() { + self.subagents = Some(subagents); + } + self + } + fn network_from_turn_context(turn_context: &TurnContext) -> Option { let network = turn_context .config @@ -126,6 +161,12 @@ impl EnvironmentContext { let shell_name = self.shell.name(); lines.push(format!(" {shell_name}")); + if let Some(current_date) = self.current_date { + lines.push(format!(" {current_date}")); + } + if let Some(timezone) = self.timezone { + lines.push(format!(" {timezone}")); + } match self.network { Some(ref network) => { lines.push(" ".to_string()); @@ -142,6 +183,11 @@ impl EnvironmentContext { // lines.push(" ".to_string()); } } + if let Some(subagents) = self.subagents { + lines.push(" ".to_string()); + lines.extend(subagents.lines().map(|line| format!(" {line}"))); + lines.push(" ".to_string()); + } ENVIRONMENT_CONTEXT_FRAGMENT.wrap(lines.join("\n")) } } @@ -171,12 +217,21 @@ mod tests { #[test] fn serialize_workspace_write_environment_context() { let cwd = test_path_buf("/repo"); - let context = EnvironmentContext::new(Some(cwd.clone()), fake_shell(), 
None); + let context = EnvironmentContext::new( + Some(cwd.clone()), + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); let expected = format!( r#" {cwd} bash + 2026-02-26 + America/Los_Angeles "#, cwd = cwd.display(), ); @@ -190,13 +245,21 @@ mod tests { allowed_domains: vec!["api.example.com".to_string(), "*.openai.com".to_string()], denied_domains: vec!["blocked.example.com".to_string()], }; - let context = - EnvironmentContext::new(Some(test_path_buf("/repo")), fake_shell(), Some(network)); + let context = EnvironmentContext::new( + Some(test_path_buf("/repo")), + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + Some(network), + None, + ); let expected = format!( r#" {} bash + 2026-02-26 + America/Los_Angeles api.example.com *.openai.com @@ -211,10 +274,19 @@ mod tests { #[test] fn serialize_read_only_environment_context() { - let context = EnvironmentContext::new(None, fake_shell(), None); + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); let expected = r#" bash + 2026-02-26 + America/Los_Angeles "#; assert_eq!(context.serialize_to_xml(), expected); @@ -222,10 +294,19 @@ mod tests { #[test] fn serialize_external_sandbox_environment_context() { - let context = EnvironmentContext::new(None, fake_shell(), None); + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); let expected = r#" bash + 2026-02-26 + America/Los_Angeles "#; assert_eq!(context.serialize_to_xml(), expected); @@ -233,10 +314,19 @@ mod tests { #[test] fn serialize_external_sandbox_with_restricted_network_environment_context() { - let context = EnvironmentContext::new(None, fake_shell(), None); + let context = EnvironmentContext::new( + None, + fake_shell(), + 
Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); let expected = r#" bash + 2026-02-26 + America/Los_Angeles "#; assert_eq!(context.serialize_to_xml(), expected); @@ -244,10 +334,19 @@ mod tests { #[test] fn serialize_full_access_environment_context() { - let context = EnvironmentContext::new(None, fake_shell(), None); + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); let expected = r#" bash + 2026-02-26 + America/Los_Angeles "#; assert_eq!(context.serialize_to_xml(), expected); @@ -255,23 +354,65 @@ mod tests { #[test] fn equals_except_shell_compares_cwd() { - let context1 = EnvironmentContext::new(Some(PathBuf::from("/repo")), fake_shell(), None); - let context2 = EnvironmentContext::new(Some(PathBuf::from("/repo")), fake_shell(), None); + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); assert!(context1.equals_except_shell(&context2)); } #[test] fn equals_except_shell_ignores_sandbox_policy() { - let context1 = EnvironmentContext::new(Some(PathBuf::from("/repo")), fake_shell(), None); - let context2 = EnvironmentContext::new(Some(PathBuf::from("/repo")), fake_shell(), None); + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); assert!(context1.equals_except_shell(&context2)); } #[test] fn equals_except_shell_compares_cwd_differences() { - let context1 = EnvironmentContext::new(Some(PathBuf::from("/repo1")), fake_shell(), None); - let context2 = EnvironmentContext::new(Some(PathBuf::from("/repo2")), 
fake_shell(), None); + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo1")), + fake_shell(), + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo2")), + fake_shell(), + None, + None, + None, + None, + ); assert!(!context1.equals_except_shell(&context2)); } @@ -286,6 +427,9 @@ mod tests { shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), }, None, + None, + None, + None, ); let context2 = EnvironmentContext::new( Some(PathBuf::from("/repo")), @@ -295,8 +439,39 @@ mod tests { shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), }, None, + None, + None, + None, ); assert!(context1.equals_except_shell(&context2)); } + + #[test] + fn serialize_environment_context_with_subagents() { + let context = EnvironmentContext::new( + Some(test_path_buf("/repo")), + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + Some("- agent-1: atlas\n- agent-2".to_string()), + ); + + let expected = format!( + r#" + {} + bash + 2026-02-26 + America/Los_Angeles + + - agent-1: atlas + - agent-2 + +"#, + test_path_buf("/repo").display() + ); + + assert_eq!(context.serialize_to_xml(), expected); + } } diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index 288ee7f2127..5a805a3b8b4 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -219,6 +219,8 @@ pub async fn process_exec_tool_call( enforce_managed_network, network: network.as_ref(), sandbox_policy_cwd: sandbox_cwd, + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions: None, codex_linux_sandbox_exe: codex_linux_sandbox_exe.as_ref(), use_linux_sandbox_bwrap, windows_sandbox_level, diff --git a/codex-rs/core/src/mcp/mod.rs b/codex-rs/core/src/mcp/mod.rs index 3abe4e9233c..d744010469b 100644 --- a/codex-rs/core/src/mcp/mod.rs +++ b/codex-rs/core/src/mcp/mod.rs @@ -139,6 +139,7 @@ fn codex_apps_mcp_server_config(config: &Config, auth: Option<&CodexAuth>) 
-> Mc enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, } } diff --git a/codex-rs/core/src/mcp/skill_dependencies.rs b/codex-rs/core/src/mcp/skill_dependencies.rs index 1969bc6322b..8d5fc3c1a35 100644 --- a/codex-rs/core/src/mcp/skill_dependencies.rs +++ b/codex-rs/core/src/mcp/skill_dependencies.rs @@ -241,6 +241,7 @@ pub(crate) async fn maybe_install_mcp_dependencies( oauth_config.http_headers, oauth_config.env_http_headers, &[], + server_config.oauth_resource.as_deref(), config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), ) @@ -387,6 +388,7 @@ fn mcp_dependency_to_server_config( enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }); } @@ -411,6 +413,7 @@ fn mcp_dependency_to_server_config( enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }); } @@ -468,6 +471,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); @@ -516,6 +520,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, )]); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 12e3d325d34..2d22351d1f8 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -2115,6 +2115,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, auth_status: McpAuthStatus::Unsupported, }; @@ -2162,6 +2163,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, auth_status: McpAuthStatus::Unsupported, }; diff --git a/codex-rs/core/src/memories/README.md b/codex-rs/core/src/memories/README.md index afbc94e4d9d..0a49b58316b 100644 --- a/codex-rs/core/src/memories/README.md +++ b/codex-rs/core/src/memories/README.md @@ -59,7 +59,14 @@ Phase 2 consolidates the latest stage-1 outputs into the filesystem memory artif What it 
does: - claims a single global phase-2 job (so only one consolidation runs at a time) -- loads a bounded set of the most recent stage-1 outputs from the state DB (the per-rollout memories produced by Phase 1, used as the consolidation input set) +- loads a bounded set of stage-1 outputs from the state DB using phase-2 + selection rules: + - ignores memories whose `last_usage` falls outside the configured + `max_unused_days` window + - for memories with no `last_usage`, falls back to `generated_at` so fresh + never-used memories can still be selected + - ranks eligible memories by `usage_count` first, then by the most recent + `last_usage` / `generated_at` - computes a completion watermark from the claimed watermark + newest input timestamps - syncs local memory artifacts under the memories root: - `raw_memories.md` (merged raw memories, latest first) diff --git a/codex-rs/core/src/memories/phase2.rs b/codex-rs/core/src/memories/phase2.rs index f86b3bbdeec..fb6e99d2df7 100644 --- a/codex-rs/core/src/memories/phase2.rs +++ b/codex-rs/core/src/memories/phase2.rs @@ -53,6 +53,7 @@ pub(super) async fn run(session: &Arc, config: Arc) { }; let root = memory_root(&config.codex_home); let max_raw_memories = config.memories.max_raw_memories_for_global; + let max_unused_days = config.memories.max_unused_days; // 1. Claim the job. let claim = match job::claim(session, db).await { @@ -76,7 +77,10 @@ pub(super) async fn run(session: &Arc, config: Arc) { }; // 3. 
Query the memories - let selection = match db.get_phase2_input_selection(max_raw_memories).await { + let selection = match db + .get_phase2_input_selection(max_raw_memories, max_unused_days) + .await + { Ok(selection) => selection, Err(err) => { tracing::error!("failed to list stage1 outputs from global: {}", err); diff --git a/codex-rs/core/src/memories/storage.rs b/codex-rs/core/src/memories/storage.rs index bd3eaa23aa1..1410c4e73a9 100644 --- a/codex-rs/core/src/memories/storage.rs +++ b/codex-rs/core/src/memories/storage.rs @@ -143,6 +143,9 @@ async fn write_rollout_summary_for_thread( writeln!(body, "rollout_path: {}", memory.rollout_path.display()) .map_err(rollout_summary_format_error)?; writeln!(body, "cwd: {}", memory.cwd.display()).map_err(rollout_summary_format_error)?; + if let Some(git_branch) = memory.git_branch.as_deref() { + writeln!(body, "git_branch: {git_branch}").map_err(rollout_summary_format_error)?; + } writeln!(body).map_err(rollout_summary_format_error)?; body.push_str(&memory.rollout_summary); body.push('\n'); @@ -273,6 +276,7 @@ mod tests { rollout_slug: rollout_slug.map(ToString::to_string), rollout_path: PathBuf::from("/tmp/rollout.jsonl"), cwd: PathBuf::from("/tmp/workspace"), + git_branch: None, generated_at: Utc.timestamp_opt(124, 0).single().expect("timestamp"), } } diff --git a/codex-rs/core/src/memories/tests.rs b/codex-rs/core/src/memories/tests.rs index 8f5fd5a81ed..62cef7eae35 100644 --- a/codex-rs/core/src/memories/tests.rs +++ b/codex-rs/core/src/memories/tests.rs @@ -88,6 +88,7 @@ async fn sync_rollout_summaries_and_raw_memories_file_keeps_latest_memories_only rollout_slug: None, rollout_path: PathBuf::from("/tmp/rollout-100.jsonl"), cwd: PathBuf::from("/tmp/workspace"), + git_branch: None, generated_at: Utc.timestamp_opt(101, 0).single().expect("timestamp"), }]; @@ -193,6 +194,7 @@ async fn sync_rollout_summaries_uses_timestamp_hash_and_sanitized_slug_filename( rollout_slug: Some("Unsafe Slug/With Spaces & Symbols + 
EXTRA_LONG_12345".to_string()), rollout_path: PathBuf::from("/tmp/rollout-200.jsonl"), cwd: PathBuf::from("/tmp/workspace"), + git_branch: Some("feature/memory-branch".to_string()), generated_at: Utc.timestamp_opt(201, 0).single().expect("timestamp"), }]; @@ -248,6 +250,7 @@ async fn sync_rollout_summaries_uses_timestamp_hash_and_sanitized_slug_filename( .expect("read rollout summary"); assert!(summary.contains(&format!("thread_id: {thread_id}"))); assert!(summary.contains("rollout_path: /tmp/rollout-200.jsonl")); + assert!(summary.contains("git_branch: feature/memory-branch")); assert!( !tokio::fs::try_exists(&stale_unslugged_path) .await @@ -294,6 +297,7 @@ task_outcome: success rollout_slug: Some("Unsafe Slug/With Spaces & Symbols + EXTRA_LONG_12345".to_string()), rollout_path: PathBuf::from("/tmp/rollout-200.jsonl"), cwd: PathBuf::from("/tmp/workspace"), + git_branch: None, generated_at: Utc.timestamp_opt(201, 0).single().expect("timestamp"), }]; @@ -378,6 +382,7 @@ mod phase2 { rollout_slug: None, rollout_path: PathBuf::from("/tmp/rollout-summary.jsonl"), cwd: PathBuf::from("/tmp/workspace"), + git_branch: None, generated_at: chrono::DateTime::::from_timestamp(source_updated_at + 1, 0) .expect("valid generated_at timestamp"), } @@ -559,7 +564,7 @@ mod phase2 { #[tokio::test] async fn dispatch_reclaims_stale_global_lock_and_starts_consolidation() { let harness = DispatchHarness::new().await; - harness.seed_stage1_output(100).await; + harness.seed_stage1_output(Utc::now().timestamp()).await; let stale_claim = harness .state_db @@ -573,12 +578,18 @@ mod phase2 { phase2::run(&harness.session, Arc::clone(&harness.config)).await; - let running_claim = harness + let post_dispatch_claim = harness .state_db .try_claim_global_phase2_job(ThreadId::new(), 3_600) .await - .expect("claim while running"); - pretty_assertions::assert_eq!(running_claim, Phase2JobClaimOutcome::SkippedRunning); + .expect("claim after stale lock dispatch"); + assert!( + matches!( + 
post_dispatch_claim, + Phase2JobClaimOutcome::SkippedRunning | Phase2JobClaimOutcome::SkippedNotDirty + ), + "stale-lock dispatch should either keep the reclaimed job running or finish it before re-claim" + ); let user_input_ops = harness.user_input_ops_count(); pretty_assertions::assert_eq!(user_input_ops, 1); diff --git a/codex-rs/core/src/models_manager/model_info.rs b/codex-rs/core/src/models_manager/model_info.rs index 19a945f951f..4824e4cd111 100644 --- a/codex-rs/core/src/models_manager/model_info.rs +++ b/codex-rs/core/src/models_manager/model_info.rs @@ -1,3 +1,4 @@ +use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelInstructionsVariables; @@ -68,10 +69,12 @@ pub(crate) fn model_info_from_slug(slug: &str) -> ModelInfo { visibility: ModelVisibility::None, supported_in_api: true, priority: 99, + availability_nux: None, upgrade: None, base_instructions: BASE_INSTRUCTIONS.to_string(), model_messages: local_personality_messages_for_slug(slug), supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, apply_patch_tool_type: None, diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index 1b858aedcfc..d643d746825 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -31,6 +31,7 @@ use tokio::sync::Mutex; use tokio::task::JoinHandle; use tracing::debug; use tracing::error; +use tracing::info; use tracing::warn; const AUDIO_IN_QUEUE_CAPACITY: usize = 256; @@ -184,6 +185,7 @@ pub(crate) async fn handle_start( let requested_session_id = params .session_id .or_else(|| Some(sess.conversation_id.to_string())); + info!("starting realtime conversation"); let events_rx = match sess .conversation .start(api_provider, None, prompt, 
requested_session_id.clone()) @@ -191,11 +193,14 @@ pub(crate) async fn handle_start( { Ok(events_rx) => events_rx, Err(err) => { + error!("failed to start realtime conversation: {err}"); send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::Other).await; return Ok(()); } }; + info!("realtime conversation started"); + sess.send_event_raw(Event { id: sub_id.clone(), msg: EventMsg::RealtimeConversationStarted(RealtimeConversationStartedEvent { @@ -211,6 +216,7 @@ pub(crate) async fn handle_start( msg, }; while let Ok(event) = events_rx.recv().await { + debug!(conversation_id = %sess_clone.conversation_id, "received realtime conversation event"); let maybe_routed_text = match &event { RealtimeEvent::ConversationItemAdded(item) => { realtime_text_from_conversation_item(item) @@ -231,6 +237,7 @@ pub(crate) async fn handle_start( .await; } if let Some(()) = sess_clone.conversation.running_state().await { + info!("realtime conversation transport closed"); sess_clone .send_event_raw(ev(EventMsg::RealtimeConversationClosed( RealtimeConversationClosedEvent { @@ -250,6 +257,7 @@ pub(crate) async fn handle_audio( params: ConversationAudioParams, ) { if let Err(err) = sess.conversation.audio_in(params.frame).await { + error!("failed to append realtime audio: {err}"); send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await; } } @@ -284,6 +292,7 @@ pub(crate) async fn handle_text( debug!(text = %params.text, "[realtime-text] appending realtime conversation text input"); if let Err(err) = sess.conversation.text_in(params.text).await { + error!("failed to append realtime text: {err}"); send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await; } } diff --git a/codex-rs/core/src/rollout/recorder.rs b/codex-rs/core/src/rollout/recorder.rs index d612bd14e58..bc2421927bd 100644 --- a/codex-rs/core/src/rollout/recorder.rs +++ b/codex-rs/core/src/rollout/recorder.rs @@ -326,7 +326,9 @@ impl RolloutRecorder { 
else { break; }; - if let Some(path) = select_resume_path_from_db_page(&db_page, filter_cwd) { + if let Some(path) = + select_resume_path_from_db_page(&db_page, filter_cwd, default_provider).await + { return Ok(Some(path)); } db_cursor = db_page.next_anchor.map(Into::into); @@ -348,7 +350,7 @@ impl RolloutRecorder { default_provider, ) .await?; - if let Some(path) = select_resume_path(&page, filter_cwd) { + if let Some(path) = select_resume_path(&page, filter_cwd, default_provider).await { return Ok(Some(path)); } cursor = page.next_cursor; @@ -961,35 +963,79 @@ impl From for ThreadsPage { } } -fn select_resume_path(page: &ThreadsPage, filter_cwd: Option<&Path>) -> Option { +async fn select_resume_path( + page: &ThreadsPage, + filter_cwd: Option<&Path>, + default_provider: &str, +) -> Option { match filter_cwd { - Some(cwd) => page.items.iter().find_map(|item| { - if item - .cwd - .as_ref() - .is_some_and(|session_cwd| cwd_matches(session_cwd, cwd)) - { - Some(item.path.clone()) - } else { - None + Some(cwd) => { + for item in &page.items { + if resume_candidate_matches_cwd( + item.path.as_path(), + item.cwd.as_deref(), + cwd, + default_provider, + ) + .await + { + return Some(item.path.clone()); + } } - }), + None + } None => page.items.first().map(|item| item.path.clone()), } } -fn select_resume_path_from_db_page( +async fn resume_candidate_matches_cwd( + rollout_path: &Path, + cached_cwd: Option<&Path>, + cwd: &Path, + default_provider: &str, +) -> bool { + if cached_cwd.is_some_and(|session_cwd| cwd_matches(session_cwd, cwd)) { + return true; + } + + if let Ok((items, _, _)) = RolloutRecorder::load_rollout_items(rollout_path).await + && let Some(latest_turn_context_cwd) = items.iter().rev().find_map(|item| match item { + RolloutItem::TurnContext(turn_context) => Some(turn_context.cwd.as_path()), + RolloutItem::SessionMeta(_) + | RolloutItem::ResponseItem(_) + | RolloutItem::Compacted(_) + | RolloutItem::EventMsg(_) => None, + }) + { + return 
cwd_matches(latest_turn_context_cwd, cwd); + } + + metadata::extract_metadata_from_rollout(rollout_path, default_provider, None) + .await + .is_ok_and(|outcome| cwd_matches(outcome.metadata.cwd.as_path(), cwd)) +} + +async fn select_resume_path_from_db_page( page: &codex_state::ThreadsPage, filter_cwd: Option<&Path>, + default_provider: &str, ) -> Option { match filter_cwd { - Some(cwd) => page.items.iter().find_map(|item| { - if cwd_matches(item.cwd.as_path(), cwd) { - Some(item.rollout_path.clone()) - } else { - None + Some(cwd) => { + for item in &page.items { + if resume_candidate_matches_cwd( + item.rollout_path.as_path(), + Some(item.cwd.as_path()), + cwd, + default_provider, + ) + .await + { + return Some(item.rollout_path.clone()); + } } - }), + None + } None => page.items.first().map(|item| item.rollout_path.clone()), } } @@ -1010,8 +1056,12 @@ mod tests { use crate::config::ConfigBuilder; use crate::features::Feature; use chrono::TimeZone; + use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_protocol::protocol::AgentMessageEvent; + use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; + use codex_protocol::protocol::SandboxPolicy; + use codex_protocol::protocol::TurnContextItem; use codex_protocol::protocol::UserMessageEvent; use pretty_assertions::assert_eq; use std::fs::File; @@ -1320,4 +1370,49 @@ mod tests { assert_eq!(repaired_path, Some(real_path)); Ok(()) } + + #[tokio::test] + async fn resume_candidate_matches_cwd_reads_latest_turn_context() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let stale_cwd = home.path().join("stale"); + let latest_cwd = home.path().join("latest"); + fs::create_dir_all(&stale_cwd)?; + fs::create_dir_all(&latest_cwd)?; + + let path = write_session_file(home.path(), "2025-01-03T13-00-00", Uuid::from_u128(9012))?; + let mut file = std::fs::OpenOptions::new().append(true).open(&path)?; + let turn_context = RolloutLine { + 
timestamp: "2025-01-03T13:00:01Z".to_string(), + item: RolloutItem::TurnContext(TurnContextItem { + turn_id: Some("turn-1".to_string()), + cwd: latest_cwd.clone(), + current_date: None, + timezone: None, + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + network: None, + model: "test-model".to_string(), + personality: None, + collaboration_mode: None, + effort: None, + summary: ReasoningSummaryConfig::Auto, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: None, + }), + }; + writeln!(file, "{}", serde_json::to_string(&turn_context)?)?; + + assert!( + resume_candidate_matches_cwd( + path.as_path(), + Some(stale_cwd.as_path()), + latest_cwd.as_path(), + "test-provider", + ) + .await + ); + Ok(()) + } } diff --git a/codex-rs/core/src/sandboxing/mod.rs b/codex-rs/core/src/sandboxing/mod.rs index 56c7bff6a68..da03f00a62c 100644 --- a/codex-rs/core/src/sandboxing/mod.rs +++ b/codex-rs/core/src/sandboxing/mod.rs @@ -17,7 +17,7 @@ use crate::protocol::SandboxPolicy; #[cfg(target_os = "macos")] use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE; #[cfg(target_os = "macos")] -use crate::seatbelt::create_seatbelt_command_args; +use crate::seatbelt::create_seatbelt_command_args_with_extensions; #[cfg(target_os = "macos")] use crate::spawn::CODEX_SANDBOX_ENV_VAR; use crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; @@ -25,6 +25,8 @@ use crate::tools::sandboxing::SandboxablePreference; use codex_network_proxy::NetworkProxy; use codex_protocol::config_types::WindowsSandboxLevel; use codex_protocol::models::FileSystemPermissions; +#[cfg(target_os = "macos")] +use codex_protocol::models::MacOsSeatbeltProfileExtensions; use codex_protocol::models::PermissionProfile; pub use codex_protocol::models::SandboxPermissions; use codex_protocol::protocol::ReadOnlyAccess; @@ -73,6 +75,8 @@ pub(crate) struct SandboxTransformRequest<'a> { // to make shared ownership explicit 
across runtime/sandbox plumbing. pub network: Option<&'a NetworkProxy>, pub sandbox_policy_cwd: &'a Path, + #[cfg(target_os = "macos")] + pub macos_seatbelt_profile_extensions: Option<&'a MacOsSeatbeltProfileExtensions>, pub codex_linux_sandbox_exe: Option<&'a PathBuf>, pub use_linux_sandbox_bwrap: bool, pub windows_sandbox_level: WindowsSandboxLevel, @@ -342,6 +346,8 @@ impl SandboxManager { enforce_managed_network, network, sandbox_policy_cwd, + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions, codex_linux_sandbox_exe, use_linux_sandbox_bwrap, windows_sandbox_level, @@ -370,12 +376,13 @@ impl SandboxManager { SandboxType::MacosSeatbelt => { let mut seatbelt_env = HashMap::new(); seatbelt_env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string()); - let mut args = create_seatbelt_command_args( + let mut args = create_seatbelt_command_args_with_extensions( command.clone(), &effective_policy, sandbox_policy_cwd, enforce_managed_network, network, + macos_seatbelt_profile_extensions, ); let mut full_command = Vec::with_capacity(1 + args.len()); full_command.push(MACOS_PATH_TO_SEATBELT_EXECUTABLE.to_string()); diff --git a/codex-rs/core/src/seatbelt_permissions.rs b/codex-rs/core/src/seatbelt_permissions.rs index aacc3a77843..6b9fa681884 100644 --- a/codex-rs/core/src/seatbelt_permissions.rs +++ b/codex-rs/core/src/seatbelt_permissions.rs @@ -3,34 +3,9 @@ use std::collections::BTreeSet; use std::path::PathBuf; -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub enum MacOsPreferencesPermission { - // IMPORTANT: ReadOnly needs to be the default because it's the security-sensitive default. - // it's important for allowing cf prefs to work. 
- #[default] - ReadOnly, - ReadWrite, - None, -} - -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub enum MacOsAutomationPermission { - #[default] - None, - All, - BundleIds(Vec), -} - -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct MacOsSeatbeltProfileExtensions { - pub macos_preferences: MacOsPreferencesPermission, - pub macos_automation: MacOsAutomationPermission, - pub macos_accessibility: bool, - pub macos_calendar: bool, -} +pub use codex_protocol::models::MacOsAutomationPermission; +pub use codex_protocol::models::MacOsPreferencesPermission; +pub use codex_protocol::models::MacOsSeatbeltProfileExtensions; #[derive(Debug, Clone, PartialEq, Eq, Default)] pub(crate) struct SeatbeltExtensionPolicy { @@ -38,25 +13,26 @@ pub(crate) struct SeatbeltExtensionPolicy { pub(crate) dir_params: Vec<(String, PathBuf)>, } -impl MacOsSeatbeltProfileExtensions { - pub fn normalized(&self) -> Self { - let mut normalized = self.clone(); - if let MacOsAutomationPermission::BundleIds(bundle_ids) = &self.macos_automation { - let bundle_ids = normalize_bundle_ids(bundle_ids); - normalized.macos_automation = if bundle_ids.is_empty() { - MacOsAutomationPermission::None - } else { - MacOsAutomationPermission::BundleIds(bundle_ids) - }; - } - normalized +fn normalized_extensions( + extensions: &MacOsSeatbeltProfileExtensions, +) -> MacOsSeatbeltProfileExtensions { + let mut normalized = extensions.clone(); + if let MacOsAutomationPermission::BundleIds(bundle_ids) = &extensions.macos_automation { + let bundle_ids = normalize_bundle_ids(bundle_ids); + normalized.macos_automation = if bundle_ids.is_empty() { + MacOsAutomationPermission::None + } else { + MacOsAutomationPermission::BundleIds(bundle_ids) + }; } + + normalized } pub(crate) fn build_seatbelt_extensions( extensions: &MacOsSeatbeltProfileExtensions, ) -> SeatbeltExtensionPolicy { - let extensions = extensions.normalized(); + let extensions = 
normalized_extensions(extensions); let mut clauses = Vec::new(); match extensions.macos_preferences { diff --git a/codex-rs/core/src/session_prefix.rs b/codex-rs/core/src/session_prefix.rs index ebf068894ad..db3ac00a6dc 100644 --- a/codex-rs/core/src/session_prefix.rs +++ b/codex-rs/core/src/session_prefix.rs @@ -12,3 +12,10 @@ pub(crate) fn format_subagent_notification_message(agent_id: &str, status: &Agen .to_string(); SUBAGENT_NOTIFICATION_FRAGMENT.wrap(payload_json) } + +pub(crate) fn format_subagent_context_line(agent_id: &str, agent_nickname: Option<&str>) -> String { + match agent_nickname.filter(|nickname| !nickname.is_empty()) { + Some(agent_nickname) => format!("- {agent_id}: {agent_nickname}"), + None => format!("- {agent_id}"), + } +} diff --git a/codex-rs/core/src/skills/permissions.rs b/codex-rs/core/src/skills/permissions.rs index 4019a310799..89c49709e37 100644 --- a/codex-rs/core/src/skills/permissions.rs +++ b/codex-rs/core/src/skills/permissions.rs @@ -8,6 +8,7 @@ use codex_protocol::models::MacOsAutomationValue; use codex_protocol::models::MacOsPermissions; #[cfg(target_os = "macos")] use codex_protocol::models::MacOsPreferencesValue; +use codex_protocol::models::MacOsSeatbeltProfileExtensions; use codex_protocol::models::PermissionProfile; use codex_utils_absolute_path::AbsolutePathBuf; use dirs::home_dir; @@ -20,10 +21,6 @@ use crate::config::types::ShellEnvironmentPolicy; use crate::protocol::AskForApproval; use crate::protocol::ReadOnlyAccess; use crate::protocol::SandboxPolicy; -#[cfg(target_os = "macos")] -use crate::seatbelt_permissions::MacOsSeatbeltProfileExtensions; -#[cfg(not(target_os = "macos"))] -type MacOsSeatbeltProfileExtensions = (); pub(crate) fn compile_permission_profile( skill_dir: &Path, diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index f653b3dfc00..0a56cab1d66 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -252,7 +252,7 @@ impl 
ThreadManager { } pub async fn list_thread_ids(&self) -> Vec { - self.state.threads.read().await.keys().copied().collect() + self.state.list_thread_ids().await } pub async fn refresh_mcp_servers(&self, refresh_config: McpServerRefreshConfig) { @@ -412,6 +412,10 @@ impl ThreadManager { } impl ThreadManagerState { + pub(crate) async fn list_thread_ids(&self) -> Vec { + self.threads.read().await.keys().copied().collect() + } + /// Fetch a thread by ID or return ThreadNotFound. pub(crate) async fn get_thread(&self, thread_id: ThreadId) -> CodexResult> { let threads = self.threads.read().await; @@ -495,6 +499,27 @@ impl ThreadManagerState { .await } + pub(crate) async fn fork_thread_with_source( + &self, + config: Config, + initial_history: InitialHistory, + agent_control: AgentControl, + session_source: SessionSource, + persist_extended_history: bool, + ) -> CodexResult { + self.spawn_thread_with_source( + config, + initial_history, + Arc::clone(&self.auth_manager), + agent_control, + session_source, + Vec::new(), + persist_extended_history, + None, + ) + .await + } + /// Spawn a new thread with optional history and register it with the manager. #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_thread( diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index 58925622a0a..0700b4d013c 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -95,15 +95,12 @@ impl ToolOutput { match self { ToolOutput::Function { body, success } => { // `custom_tool_call` is the Responses API item type for freeform - // tools (`ToolSpec::Freeform`, e.g. freeform `apply_patch`). - // Those payloads must round-trip as `custom_tool_call_output` - // with plain string output. + // tools (`ToolSpec::Freeform`, e.g. freeform `apply_patch` or + // `js_repl`). if matches!(payload, ToolPayload::Custom { .. }) { - // Freeform/custom tools (`custom_tool_call`) use the custom - // output wire shape and remain string-only. 
return ResponseInputItem::CustomToolCallOutput { call_id: call_id.to_string(), - output: body.to_text().unwrap_or_default(), + output: FunctionCallOutputPayload { body, success }, }; } @@ -183,7 +180,9 @@ mod tests { match response { ResponseInputItem::CustomToolCallOutput { call_id, output } => { assert_eq!(call_id, "call-42"); - assert_eq!(output, "patched"); + assert_eq!(output.text_content(), Some("patched")); + assert!(output.content_items().is_none()); + assert_eq!(output.success, Some(true)); } other => panic!("expected CustomToolCallOutput, got {other:?}"), } @@ -234,8 +233,21 @@ mod tests { match response { ResponseInputItem::CustomToolCallOutput { call_id, output } => { + let expected = vec![ + FunctionCallOutputContentItem::InputText { + text: "line 1".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + }, + FunctionCallOutputContentItem::InputText { + text: "line 2".to_string(), + }, + ]; assert_eq!(call_id, "call-99"); - assert_eq!(output, "line 1\nline 2"); + assert_eq!(output.content_items(), Some(expected.as_slice())); + assert_eq!(output.body.to_text().as_deref(), Some("line 1\nline 2")); + assert_eq!(output.success, Some(true)); } other => panic!("expected CustomToolCallOutput, got {other:?}"), } diff --git a/codex-rs/core/src/tools/handlers/js_repl.rs b/codex-rs/core/src/tools/handlers/js_repl.rs index 4488b4ea51e..362d25b81e4 100644 --- a/codex-rs/core/src/tools/handlers/js_repl.rs +++ b/codex-rs/core/src/tools/handlers/js_repl.rs @@ -155,9 +155,13 @@ impl ToolHandler for JsReplHandler { }; let content = result.output; - let items = vec![FunctionCallOutputContentItem::InputText { - text: content.clone(), - }]; + let mut items = Vec::with_capacity(result.content_items.len() + 1); + if !content.is_empty() { + items.push(FunctionCallOutputContentItem::InputText { + text: content.clone(), + }); + } + items.extend(result.content_items); emit_js_repl_exec_end( session.as_ref(), @@ 
-170,7 +174,11 @@ impl ToolHandler for JsReplHandler { .await; Ok(ToolOutput::Function { - body: FunctionCallOutputBody::ContentItems(items), + body: if items.is_empty() { + FunctionCallOutputBody::Text(content) + } else { + FunctionCallOutputBody::ContentItems(items) + }, success: Some(true), }) } diff --git a/codex-rs/core/src/tools/handlers/multi_agents.rs b/codex-rs/core/src/tools/handlers/multi_agents.rs index 6e62fd04fd1..16ff943ab6e 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents.rs @@ -91,6 +91,7 @@ impl ToolHandler for MultiAgentHandler { mod spawn { use super::*; + use crate::agent::control::SpawnAgentOptions; use crate::agent::role::DEFAULT_ROLE_NAME; use crate::agent::role::apply_role_to_config; @@ -103,6 +104,8 @@ mod spawn { message: Option, items: Option>, agent_type: Option, + #[serde(default)] + fork_context: bool, } #[derive(Debug, Serialize)] @@ -155,7 +158,7 @@ mod spawn { let result = session .services .agent_control - .spawn_agent( + .spawn_agent_with_options( config, input_items, Some(thread_spawn_source( @@ -163,6 +166,9 @@ mod spawn { child_depth, role_name, )), + SpawnAgentOptions { + fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()), + }, ) .await .map_err(collab_spawn_error); @@ -914,7 +920,7 @@ fn build_agent_shared_config(turn: &TurnContext) -> Result, } struct KernelState { @@ -119,10 +126,37 @@ struct ExecContext { #[derive(Default)] struct ExecToolCalls { in_flight: usize, + content_items: Vec, notify: Arc, cancel: CancellationToken, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(clippy::enum_variant_names)] +enum JsReplToolCallPayloadKind { + MessageContent, + FunctionText, + FunctionContentItems, + CustomText, + CustomContentItems, + McpResult, + McpErrorResult, + Error, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +struct JsReplToolCallResponseSummary { + response_type: Option, + payload_kind: Option, + payload_text_preview: 
Option, + payload_text_length: Option, + payload_item_count: Option, + text_item_count: Option, + image_item_count: Option, + structured_content_present: Option, + result_is_error: Option, +} + enum KernelStreamEnd { Shutdown, StdoutEof, @@ -338,6 +372,21 @@ impl JsReplManager { Some(state.cancel.clone()) } + async fn record_exec_tool_call_content_items( + exec_tool_calls: &Arc>>, + exec_id: &str, + content_items: Vec, + ) { + if content_items.is_empty() { + return; + } + + let mut calls = exec_tool_calls.lock().await; + if let Some(state) = calls.get_mut(exec_id) { + state.content_items.extend(content_items); + } + } + async fn finish_exec_tool_call( exec_tool_calls: &Arc>>, exec_id: &str, @@ -404,6 +453,205 @@ impl JsReplManager { } } + fn log_tool_call_response( + req: &RunToolRequest, + ok: bool, + summary: &JsReplToolCallResponseSummary, + response: Option<&JsonValue>, + error: Option<&str>, + ) { + info!( + exec_id = %req.exec_id, + tool_call_id = %req.id, + tool_name = %req.tool_name, + ok, + summary = ?summary, + "js_repl nested tool call completed" + ); + if let Some(response) = response { + trace!( + exec_id = %req.exec_id, + tool_call_id = %req.id, + tool_name = %req.tool_name, + response_json = %response, + "js_repl nested tool call raw response" + ); + } + if let Some(error) = error { + trace!( + exec_id = %req.exec_id, + tool_call_id = %req.id, + tool_name = %req.tool_name, + error = %error, + "js_repl nested tool call raw error" + ); + } + } + + fn summarize_text_payload( + response_type: Option<&str>, + payload_kind: JsReplToolCallPayloadKind, + text: &str, + ) -> JsReplToolCallResponseSummary { + JsReplToolCallResponseSummary { + response_type: response_type.map(str::to_owned), + payload_kind: Some(payload_kind), + payload_text_preview: (!text.is_empty()).then(|| { + truncate_text( + text, + TruncationPolicy::Bytes(JS_REPL_TOOL_RESPONSE_TEXT_PREVIEW_MAX_BYTES), + ) + }), + payload_text_length: Some(text.len()), + ..Default::default() + } + } + + fn 
summarize_function_output_payload( + response_type: &str, + payload_kind: JsReplToolCallPayloadKind, + output: &FunctionCallOutputPayload, + ) -> JsReplToolCallResponseSummary { + let (payload_item_count, text_item_count, image_item_count) = + if let Some(items) = output.content_items() { + let text_item_count = items + .iter() + .filter(|item| matches!(item, FunctionCallOutputContentItem::InputText { .. })) + .count(); + let image_item_count = items.len().saturating_sub(text_item_count); + ( + Some(items.len()), + Some(text_item_count), + Some(image_item_count), + ) + } else { + (None, None, None) + }; + let payload_text = output.body.to_text(); + JsReplToolCallResponseSummary { + response_type: Some(response_type.to_string()), + payload_kind: Some(payload_kind), + payload_text_preview: payload_text.as_deref().and_then(|text| { + (!text.is_empty()).then(|| { + truncate_text( + text, + TruncationPolicy::Bytes(JS_REPL_TOOL_RESPONSE_TEXT_PREVIEW_MAX_BYTES), + ) + }) + }), + payload_text_length: payload_text.as_ref().map(String::len), + payload_item_count, + text_item_count, + image_item_count, + ..Default::default() + } + } + + fn summarize_message_payload(content: &[ContentItem]) -> JsReplToolCallResponseSummary { + let text_item_count = content + .iter() + .filter(|item| { + matches!( + item, + ContentItem::InputText { .. } | ContentItem::OutputText { .. } + ) + }) + .count(); + let image_item_count = content.len().saturating_sub(text_item_count); + let payload_text = content + .iter() + .filter_map(|item| match item { + ContentItem::InputText { text } | ContentItem::OutputText { text } + if !text.trim().is_empty() => + { + Some(text.as_str()) + } + ContentItem::InputText { .. } + | ContentItem::InputImage { .. } + | ContentItem::OutputText { .. 
} => None, + }) + .collect::>(); + let payload_text = if payload_text.is_empty() { + None + } else { + Some(payload_text.join("\n")) + }; + JsReplToolCallResponseSummary { + response_type: Some("message".to_string()), + payload_kind: Some(JsReplToolCallPayloadKind::MessageContent), + payload_text_preview: payload_text.as_deref().and_then(|text| { + (!text.is_empty()).then(|| { + truncate_text( + text, + TruncationPolicy::Bytes(JS_REPL_TOOL_RESPONSE_TEXT_PREVIEW_MAX_BYTES), + ) + }) + }), + payload_text_length: payload_text.as_ref().map(String::len), + payload_item_count: Some(content.len()), + text_item_count: Some(text_item_count), + image_item_count: Some(image_item_count), + ..Default::default() + } + } + + fn summarize_tool_call_response(response: &ResponseInputItem) -> JsReplToolCallResponseSummary { + match response { + ResponseInputItem::Message { content, .. } => Self::summarize_message_payload(content), + ResponseInputItem::FunctionCallOutput { output, .. } => { + let payload_kind = if output.content_items().is_some() { + JsReplToolCallPayloadKind::FunctionContentItems + } else { + JsReplToolCallPayloadKind::FunctionText + }; + Self::summarize_function_output_payload( + "function_call_output", + payload_kind, + output, + ) + } + ResponseInputItem::CustomToolCallOutput { output, .. } => { + let payload_kind = if output.content_items().is_some() { + JsReplToolCallPayloadKind::CustomContentItems + } else { + JsReplToolCallPayloadKind::CustomText + }; + Self::summarize_function_output_payload( + "custom_tool_call_output", + payload_kind, + output, + ) + } + ResponseInputItem::McpToolCallOutput { result, .. 
} => match result { + Ok(result) => { + let output = FunctionCallOutputPayload::from(result); + let mut summary = Self::summarize_function_output_payload( + "mcp_tool_call_output", + JsReplToolCallPayloadKind::McpResult, + &output, + ); + summary.payload_item_count = Some(result.content.len()); + summary.structured_content_present = Some(result.structured_content.is_some()); + summary.result_is_error = Some(result.is_error.unwrap_or(false)); + summary + } + Err(error) => { + let mut summary = Self::summarize_text_payload( + Some("mcp_tool_call_output"), + JsReplToolCallPayloadKind::McpErrorResult, + error, + ); + summary.result_is_error = Some(true); + summary + } + }, + } + } + + fn summarize_tool_call_error(error: &str) -> JsReplToolCallResponseSummary { + Self::summarize_text_payload(None, JsReplToolCallPayloadKind::Error, error) + } + pub async fn reset(&self) -> Result<(), FunctionCallError> { let _permit = self.exec_lock.clone().acquire_owned().await.map_err(|_| { FunctionCallError::RespondToModel("js_repl execution unavailable".to_string()) @@ -546,7 +794,13 @@ impl JsReplManager { }; match response { - ExecResultMessage::Ok { output } => Ok(JsExecResult { output }), + ExecResultMessage::Ok { content_items } => { + let (output, content_items) = split_exec_result_content_items(content_items); + Ok(JsExecResult { + output, + content_items, + }) + } ExecResultMessage::Err { message } => Err(FunctionCallError::RespondToModel(message)), } } @@ -613,6 +867,8 @@ impl JsReplManager { enforce_managed_network: has_managed_network_requirements, network: None, sandbox_policy_cwd: &turn.cwd, + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions: None, codex_linux_sandbox_exe: turn.codex_linux_sandbox_exe.as_ref(), use_linux_sandbox_bwrap: turn .features @@ -848,10 +1104,22 @@ impl JsReplManager { error, } => { JsReplManager::wait_for_exec_tool_calls_map(&exec_tool_calls, &id).await; + let content_items = { + let calls = exec_tool_calls.lock().await; + calls + 
.get(&id) + .map(|state| state.content_items.clone()) + .unwrap_or_default() + }; let mut pending = pending_execs.lock().await; if let Some(tx) = pending.remove(&id) { let payload = if ok { - ExecResultMessage::Ok { output } + ExecResultMessage::Ok { + content_items: build_exec_result_content_items( + output, + content_items, + ), + } } else { ExecResultMessage::Err { message: error @@ -908,7 +1176,11 @@ impl JsReplManager { response: None, error: Some("js_repl execution reset".to_string()), }, - result = JsReplManager::run_tool_request(ctx, req) => result, + result = JsReplManager::run_tool_request( + ctx, + req, + Arc::clone(&exec_tool_calls_for_task), + ) => result, } } None => RunToolResult { @@ -1002,13 +1274,20 @@ impl JsReplManager { } } - async fn run_tool_request(exec: ExecContext, req: RunToolRequest) -> RunToolResult { + async fn run_tool_request( + exec: ExecContext, + req: RunToolRequest, + exec_tool_calls: Arc>>, + ) -> RunToolResult { if is_js_repl_internal_tool(&req.tool_name) { + let error = "js_repl cannot invoke itself".to_string(); + let summary = Self::summarize_tool_call_error(&error); + Self::log_tool_call_response(&req, false, &summary, None, Some(&error)); return RunToolResult { id: req.id, ok: false, response: None, - error: Some("js_repl cannot invoke itself".to_string()), + error: Some(error), }; } @@ -1072,62 +1351,50 @@ impl JsReplManager { .await { Ok(response) => { - if let ResponseInputItem::FunctionCallOutput { output, .. 
} = &response - && let Some(items) = output.content_items() - { - let mut has_image = false; - let mut content = Vec::with_capacity(items.len()); - for item in items { - match item { - FunctionCallOutputContentItem::InputText { text } => { - content.push(ContentItem::InputText { text: text.clone() }); - } - FunctionCallOutputContentItem::InputImage { image_url } => { - has_image = true; - content.push(ContentItem::InputImage { - image_url: image_url.clone(), - }); - } + if let Some(items) = response_content_items(&response) { + Self::record_exec_tool_call_content_items( + &exec_tool_calls, + &req.exec_id, + items, + ) + .await; + } + + let summary = Self::summarize_tool_call_response(&response); + match serde_json::to_value(response) { + Ok(value) => { + Self::log_tool_call_response(&req, true, &summary, Some(&value), None); + RunToolResult { + id: req.id, + ok: true, + response: Some(value), + error: None, } } - - if has_image - && session - .inject_response_items(vec![ResponseInputItem::Message { - role: "user".to_string(), - content, - }]) - .await - .is_err() - { - warn!( - tool_name = %tool_name, - "js_repl tool call returned image content but there was no active turn to attach it to" - ); + Err(err) => { + let error = format!("failed to serialize tool output: {err}"); + let summary = Self::summarize_tool_call_error(&error); + Self::log_tool_call_response(&req, false, &summary, None, Some(&error)); + RunToolResult { + id: req.id, + ok: false, + response: None, + error: Some(error), + } } } - - match serde_json::to_value(response) { - Ok(value) => RunToolResult { - id: req.id, - ok: true, - response: Some(value), - error: None, - }, - Err(err) => RunToolResult { - id: req.id, - ok: false, - response: None, - error: Some(format!("failed to serialize tool output: {err}")), - }, + } + Err(err) => { + let error = err.to_string(); + let summary = Self::summarize_tool_call_error(&error); + Self::log_tool_call_response(&req, false, &summary, None, Some(&error)); + 
RunToolResult { + id: req.id, + ok: false, + response: None, + error: Some(error), } } - Err(err) => RunToolResult { - id: req.id, - ok: false, - response: None, - error: Some(err.to_string()), - }, } } @@ -1165,6 +1432,50 @@ impl JsReplManager { } } +fn response_content_items( + response: &ResponseInputItem, +) -> Option> { + match response { + ResponseInputItem::FunctionCallOutput { output, .. } + | ResponseInputItem::CustomToolCallOutput { output, .. } => output + .content_items() + .map(<[FunctionCallOutputContentItem]>::to_vec), + ResponseInputItem::McpToolCallOutput { result, .. } => match result { + Ok(result) => FunctionCallOutputPayload::from(result) + .content_items() + .map(<[FunctionCallOutputContentItem]>::to_vec), + Err(_) => None, + }, + ResponseInputItem::Message { .. } => None, + } +} + +fn build_exec_result_content_items( + output: String, + content_items: Vec, +) -> Vec { + let mut all_content_items = Vec::with_capacity(content_items.len() + 1); + all_content_items.push(FunctionCallOutputContentItem::InputText { text: output }); + all_content_items.extend(content_items); + all_content_items +} + +fn split_exec_result_content_items( + mut content_items: Vec, +) -> (String, Vec) { + match content_items.first() { + Some(FunctionCallOutputContentItem::InputText { .. }) => { + let FunctionCallOutputContentItem::InputText { text } = content_items.remove(0) else { + unreachable!("first content item should be input_text"); + }; + (text, content_items) + } + Some(FunctionCallOutputContentItem::InputImage { .. 
}) | None => { + (String::new(), content_items) + } + } +} + fn is_freeform_tool(specs: &[ToolSpec], name: &str) -> bool { specs .iter() @@ -1220,8 +1531,12 @@ struct RunToolResult { #[derive(Debug)] enum ExecResultMessage { - Ok { output: String }, - Err { message: String }, + Ok { + content_items: Vec, + }, + Err { + message: String, + }, } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -1359,7 +1674,8 @@ mod tests { use codex_protocol::dynamic_tools::DynamicToolCallOutputContentItem; use codex_protocol::dynamic_tools::DynamicToolResponse; use codex_protocol::dynamic_tools::DynamicToolSpec; - use codex_protocol::models::ContentItem; + use codex_protocol::models::FunctionCallOutputContentItem; + use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_protocol::openai_models::InputModality; use pretty_assertions::assert_eq; @@ -1577,6 +1893,84 @@ mod tests { ); } + #[test] + fn summarize_tool_call_response_for_multimodal_function_output() { + let response = ResponseInputItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,abcd".to_string(), + }, + ]), + }; + + let actual = JsReplManager::summarize_tool_call_response(&response); + + assert_eq!( + actual, + JsReplToolCallResponseSummary { + response_type: Some("function_call_output".to_string()), + payload_kind: Some(JsReplToolCallPayloadKind::FunctionContentItems), + payload_text_preview: None, + payload_text_length: None, + payload_item_count: Some(1), + text_item_count: Some(0), + image_item_count: Some(1), + structured_content_present: None, + result_is_error: None, + } + ); + } + + #[test] + fn summarize_tool_call_response_for_multimodal_custom_output() { + let response = ResponseInputItem::CustomToolCallOutput { + call_id: "call-1".to_string(), + output: 
FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,abcd".to_string(), + }, + ]), + }; + + let actual = JsReplManager::summarize_tool_call_response(&response); + + assert_eq!( + actual, + JsReplToolCallResponseSummary { + response_type: Some("custom_tool_call_output".to_string()), + payload_kind: Some(JsReplToolCallPayloadKind::CustomContentItems), + payload_text_preview: None, + payload_text_length: None, + payload_item_count: Some(1), + text_item_count: Some(0), + image_item_count: Some(1), + structured_content_present: None, + result_is_error: None, + } + ); + } + + #[test] + fn summarize_tool_call_error_marks_error_payload() { + let actual = JsReplManager::summarize_tool_call_error("tool failed"); + + assert_eq!( + actual, + JsReplToolCallResponseSummary { + response_type: None, + payload_kind: Some(JsReplToolCallPayloadKind::Error), + payload_text_preview: Some("tool failed".to_string()), + payload_text_length: Some("tool failed".len()), + payload_item_count: None, + text_item_count: None, + image_item_count: None, + structured_content_present: None, + result_is_error: None, + } + ); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn reset_clears_inflight_exec_tool_calls_without_waiting() { let manager = JsReplManager::new(None, Vec::new()) @@ -2017,20 +2411,22 @@ console.log(out.output?.body?.text ?? 
""); ) .await?; assert!(result.output.contains("function_call_output")); - - let pending_input = session.get_pending_input().await; - let [ResponseInputItem::Message { role, content }] = pending_input.as_slice() else { - panic!( - "view_image should inject exactly one pending input message, got {pending_input:?}" - ); - }; - assert_eq!(role, "user"); - let [ContentItem::InputImage { image_url }] = content.as_slice() else { - panic!( - "view_image should inject exactly one input_image content item, got {content:?}" - ); + assert_eq!( + result.content_items.as_slice(), + [FunctionCallOutputContentItem::InputImage { + image_url: + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" + .to_string(), + }] + .as_slice() + ); + let [FunctionCallOutputContentItem::InputImage { image_url }] = + result.content_items.as_slice() + else { + panic!("view_image should return exactly one input_image content item"); }; assert!(image_url.starts_with("data:image/png;base64,")); + assert!(session.get_pending_input().await.is_empty()); Ok(()) } @@ -2111,22 +2507,18 @@ console.log(out.type); response_watcher_result?; let result = result?; assert!(result.output.contains("function_call_output")); - - let pending_input = session.get_pending_input().await; assert_eq!( - pending_input, - vec![ResponseInputItem::Message { - role: "user".to_string(), - content: vec![ - ContentItem::InputText { - text: "inline image note".to_string(), - }, - ContentItem::InputImage { - image_url: image_url.to_string(), - }, - ], - }] + result.content_items, + vec![ + FunctionCallOutputContentItem::InputText { + text: "inline image note".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: image_url.to_string(), + }, + ] ); + assert!(session.get_pending_input().await.is_empty()); Ok(()) } diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 04d505c5841..5f2fc89e5fd 100644 --- 
a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -117,7 +117,10 @@ impl ToolCallRuntime { match &call.payload { ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput { call_id: call.call_id.clone(), - output: Self::abort_message(call, secs), + output: FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text(Self::abort_message(call, secs)), + ..Default::default() + }, }, ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput { call_id: call.call_id.clone(), diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 7a0be68df3f..3c021cc582c 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -379,6 +379,7 @@ async fn dispatch_after_tool_use_hook( .dispatch(HookPayload { session_id: session.conversation_id, cwd: turn.cwd.clone(), + client: turn.app_server_client_name.clone(), triggered_at: chrono::Utc::now(), hook_event: HookEvent::AfterToolUse { event: HookEventAfterToolUse { diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 4897b4ea6ff..a55fb5fd5a6 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -197,7 +197,10 @@ impl ToolRouter { if payload_outputs_custom { ResponseInputItem::CustomToolCallOutput { call_id, - output: message, + output: codex_protocol::models::FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text(message), + success: Some(false), + }, } } else { ResponseInputItem::FunctionCallOutput { diff --git a/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs b/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs index 5bcf809e213..ae1df34ece7 100644 --- a/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs +++ b/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs @@ -12,12 +12,14 @@ use crate::skills::SkillMetadata; use crate::tools::runtimes::ExecveSessionApproval; use 
crate::tools::runtimes::build_command_spec; use crate::tools::sandboxing::SandboxAttempt; +use crate::tools::sandboxing::SandboxablePreference; use crate::tools::sandboxing::ToolCtx; use crate::tools::sandboxing::ToolError; use codex_execpolicy::Decision; use codex_execpolicy::Policy; use codex_execpolicy::RuleMatch; use codex_protocol::config_types::WindowsSandboxLevel; +use codex_protocol::models::MacOsSeatbeltProfileExtensions; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::NetworkPolicyRuleAction; @@ -26,11 +28,15 @@ use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::SandboxPolicy; use codex_shell_command::bash::parse_shell_lc_plain_commands; use codex_shell_command::bash::parse_shell_lc_single_command_prefix; -use codex_shell_escalation::EscalateAction; use codex_shell_escalation::EscalateServer; +use codex_shell_escalation::EscalationDecision; +use codex_shell_escalation::EscalationExecution; +use codex_shell_escalation::EscalationPermissions; use codex_shell_escalation::EscalationPolicy; use codex_shell_escalation::ExecParams; use codex_shell_escalation::ExecResult; +use codex_shell_escalation::Permissions as EscalatedPermissions; +use codex_shell_escalation::PreparedExec; use codex_shell_escalation::ShellCommandExecutor; use codex_shell_escalation::Stopwatch; use codex_utils_absolute_path::AbsolutePathBuf; @@ -105,6 +111,15 @@ pub(super) async fn try_run_zsh_fork( sandbox_permissions, justification, arg0, + sandbox_policy_cwd: ctx.turn.cwd.clone(), + macos_seatbelt_profile_extensions: ctx + .turn + .config + .permissions + .macos_seatbelt_profile_extensions + .clone(), + codex_linux_sandbox_exe: ctx.turn.codex_linux_sandbox_exe.clone(), + use_linux_sandbox_bwrap: ctx.turn.features.enabled(Feature::UseLinuxSandboxBwrap), }; let main_execve_wrapper_exe = ctx .session @@ -136,6 +151,7 @@ pub(super) async fn try_run_zsh_fork( approval_policy: 
ctx.turn.approval_policy.value(), sandbox_policy: attempt.policy.clone(), sandbox_permissions: req.sandbox_permissions, + prompt_permissions: req.additional_permissions.clone(), stopwatch: stopwatch.clone(), }; @@ -146,7 +162,7 @@ pub(super) async fn try_run_zsh_fork( ); let exec_result = escalate_server - .exec(exec_params, cancel_token, &command_executor) + .exec(exec_params, cancel_token, Arc::new(command_executor)) .await .map_err(|err| ToolError::Rejected(err.to_string()))?; @@ -161,6 +177,7 @@ struct CoreShellActionProvider { approval_policy: AskForApproval, sandbox_policy: SandboxPolicy, sandbox_permissions: SandboxPermissions, + prompt_permissions: Option, stopwatch: Stopwatch, } @@ -182,6 +199,55 @@ impl CoreShellActionProvider { }) } + fn shell_request_escalation_execution( + sandbox_permissions: SandboxPermissions, + sandbox_policy: &SandboxPolicy, + additional_permissions: Option<&PermissionProfile>, + macos_seatbelt_profile_extensions: Option<&MacOsSeatbeltProfileExtensions>, + ) -> EscalationExecution { + match sandbox_permissions { + SandboxPermissions::UseDefault => EscalationExecution::TurnDefault, + SandboxPermissions::RequireEscalated => EscalationExecution::Unsandboxed, + SandboxPermissions::WithAdditionalPermissions => additional_permissions + .map(|_| { + // Shell request additional permissions were already normalized and + // merged into the first-attempt sandbox policy. 
+ EscalationExecution::Permissions(EscalationPermissions::Permissions( + EscalatedPermissions { + sandbox_policy: sandbox_policy.clone(), + macos_seatbelt_profile_extensions: macos_seatbelt_profile_extensions + .cloned(), + }, + )) + }) + .unwrap_or(EscalationExecution::TurnDefault), + } + } + + fn skill_escalation_execution(skill: &SkillMetadata) -> EscalationExecution { + skill + .permissions + .as_ref() + .map(|permissions| { + EscalationExecution::Permissions(EscalationPermissions::Permissions( + EscalatedPermissions { + sandbox_policy: permissions.sandbox_policy.get().clone(), + macos_seatbelt_profile_extensions: permissions + .macos_seatbelt_profile_extensions + .clone(), + }, + )) + }) + .or_else(|| { + skill + .permission_profile + .clone() + .map(EscalationPermissions::PermissionProfile) + .map(EscalationExecution::Permissions) + }) + .unwrap_or(EscalationExecution::TurnDefault) + } + async fn prompt( &self, program: &AbsolutePathBuf, @@ -265,22 +331,21 @@ impl CoreShellActionProvider { program: &AbsolutePathBuf, argv: &[String], workdir: &AbsolutePathBuf, - additional_permissions: Option, + prompt_permissions: Option, + escalation_execution: EscalationExecution, decision_source: DecisionSource, - ) -> anyhow::Result { + ) -> anyhow::Result { let action = match decision { - Decision::Forbidden => EscalateAction::Deny { - reason: Some("Execution forbidden by policy".to_string()), - }, + Decision::Forbidden => { + EscalationDecision::deny(Some("Execution forbidden by policy".to_string())) + } Decision::Prompt => { if matches!( self.approval_policy, AskForApproval::Never | AskForApproval::Reject(RejectConfig { rules: true, .. 
}) ) { - EscalateAction::Deny { - reason: Some("Execution forbidden by policy".to_string()), - } + EscalationDecision::deny(Some("Execution forbidden by policy".to_string())) } else { match self .prompt( @@ -288,7 +353,7 @@ impl CoreShellActionProvider { argv, workdir, &self.stopwatch, - additional_permissions, + prompt_permissions, &decision_source, ) .await? @@ -296,9 +361,9 @@ impl CoreShellActionProvider { ReviewDecision::Approved | ReviewDecision::ApprovedExecpolicyAmendment { .. } => { if needs_escalation { - EscalateAction::Escalate + EscalationDecision::escalate(escalation_execution.clone()) } else { - EscalateAction::Run + EscalationDecision::run() } } ReviewDecision::ApprovedForSession => { @@ -323,9 +388,9 @@ impl CoreShellActionProvider { } if needs_escalation { - EscalateAction::Escalate + EscalationDecision::escalate(escalation_execution.clone()) } else { - EscalateAction::Run + EscalationDecision::run() } } ReviewDecision::NetworkPolicyAmendment { @@ -333,29 +398,29 @@ impl CoreShellActionProvider { } => match network_policy_amendment.action { NetworkPolicyRuleAction::Allow => { if needs_escalation { - EscalateAction::Escalate + EscalationDecision::escalate(escalation_execution.clone()) } else { - EscalateAction::Run + EscalationDecision::run() } } - NetworkPolicyRuleAction::Deny => EscalateAction::Deny { - reason: Some("User denied execution".to_string()), - }, - }, - ReviewDecision::Denied => EscalateAction::Deny { - reason: Some("User denied execution".to_string()), - }, - ReviewDecision::Abort => EscalateAction::Deny { - reason: Some("User cancelled execution".to_string()), + NetworkPolicyRuleAction::Deny => { + EscalationDecision::deny(Some("User denied execution".to_string())) + } }, + ReviewDecision::Denied => { + EscalationDecision::deny(Some("User denied execution".to_string())) + } + ReviewDecision::Abort => { + EscalationDecision::deny(Some("User cancelled execution".to_string())) + } } } } Decision::Allow => { if needs_escalation { - 
EscalateAction::Escalate + EscalationDecision::escalate(escalation_execution) } else { - EscalateAction::Run + EscalationDecision::run() } } }; @@ -373,7 +438,7 @@ impl EscalationPolicy for CoreShellActionProvider { program: &AbsolutePathBuf, argv: &[String], workdir: &AbsolutePathBuf, - ) -> anyhow::Result { + ) -> anyhow::Result { tracing::debug!( "Determining escalation action for command {program:?} with args {argv:?} in {workdir:?}" ); @@ -394,15 +459,13 @@ impl EscalationPolicy for CoreShellActionProvider { tracing::debug!( "Found session approval for {program:?}, allowing execution without further checks" ); - // TODO(mbolin): We need to include the permissions with the - // escalation decision so it can be run with the appropriate - // permissions. - let _permissions = approval + let execution = approval .skill .as_ref() - .and_then(|s| s.permission_profile.clone()); + .map(Self::skill_escalation_execution) + .unwrap_or(EscalationExecution::TurnDefault); - return Ok(EscalateAction::Escalate); + return Ok(EscalationDecision::escalate(execution)); } // In the usual case, the execve wrapper reports the command being @@ -424,6 +487,7 @@ impl EscalationPolicy for CoreShellActionProvider { argv, workdir, skill.permission_profile.clone(), + Self::skill_escalation_execution(&skill), decision_source, ) .await; @@ -464,13 +528,24 @@ impl EscalationPolicy for CoreShellActionProvider { } else { DecisionSource::UnmatchedCommandFallback }; + let escalation_execution = Self::shell_request_escalation_execution( + self.sandbox_permissions, + &self.sandbox_policy, + self.prompt_permissions.as_ref(), + self.turn + .config + .permissions + .macos_seatbelt_profile_extensions + .as_ref(), + ); self.process_decision( evaluation.decision, needs_escalation, program, argv, workdir, - None, + self.prompt_permissions.clone(), + escalation_execution, decision_source, ) .await @@ -488,6 +563,10 @@ struct CoreShellCommandExecutor { sandbox_permissions: SandboxPermissions, justification: 
Option, arg0: Option, + sandbox_policy_cwd: PathBuf, + macos_seatbelt_profile_extensions: Option, + codex_linux_sandbox_exe: Option, + use_linux_sandbox_bwrap: bool, } #[async_trait::async_trait] @@ -533,6 +612,126 @@ impl ShellCommandExecutor for CoreShellCommandExecutor { timed_out: result.timed_out, }) } + + async fn prepare_escalated_exec( + &self, + program: &AbsolutePathBuf, + argv: &[String], + workdir: &AbsolutePathBuf, + env: HashMap, + execution: EscalationExecution, + ) -> anyhow::Result { + let command = join_program_and_argv(program, argv); + let Some(first_arg) = argv.first() else { + return Err(anyhow::anyhow!( + "intercepted exec request must contain argv[0]" + )); + }; + + let prepared = match execution { + EscalationExecution::Unsandboxed => PreparedExec { + command, + cwd: workdir.to_path_buf(), + env, + arg0: Some(first_arg.clone()), + }, + EscalationExecution::TurnDefault => self.prepare_sandboxed_exec( + command, + workdir, + env, + &self.sandbox_policy, + None, + self.macos_seatbelt_profile_extensions.as_ref(), + )?, + EscalationExecution::Permissions(EscalationPermissions::PermissionProfile( + permission_profile, + )) => self.prepare_sandboxed_exec( + command, + workdir, + env, + &self.sandbox_policy, + Some(permission_profile), + None, + )?, + EscalationExecution::Permissions(EscalationPermissions::Permissions(permissions)) => { + self.prepare_sandboxed_exec( + command, + workdir, + env, + &permissions.sandbox_policy, + None, + permissions.macos_seatbelt_profile_extensions.as_ref(), + )? 
+ } + }; + + Ok(prepared) + } +} + +impl CoreShellCommandExecutor { + fn prepare_sandboxed_exec( + &self, + command: Vec, + workdir: &AbsolutePathBuf, + env: HashMap, + sandbox_policy: &SandboxPolicy, + additional_permissions: Option, + #[cfg(target_os = "macos")] macos_seatbelt_profile_extensions: Option< + &MacOsSeatbeltProfileExtensions, + >, + #[cfg(not(target_os = "macos"))] _macos_seatbelt_profile_extensions: Option< + &MacOsSeatbeltProfileExtensions, + >, + ) -> anyhow::Result { + let (program, args) = command + .split_first() + .ok_or_else(|| anyhow::anyhow!("prepared command must not be empty"))?; + let sandbox_manager = crate::sandboxing::SandboxManager::new(); + let sandbox = sandbox_manager.select_initial( + sandbox_policy, + SandboxablePreference::Auto, + self.windows_sandbox_level, + self.network.is_some(), + ); + let mut exec_request = + sandbox_manager.transform(crate::sandboxing::SandboxTransformRequest { + spec: crate::sandboxing::CommandSpec { + program: program.clone(), + args: args.to_vec(), + cwd: workdir.to_path_buf(), + env, + expiration: ExecExpiration::DefaultTimeout, + sandbox_permissions: if additional_permissions.is_some() { + SandboxPermissions::WithAdditionalPermissions + } else { + SandboxPermissions::UseDefault + }, + additional_permissions, + justification: self.justification.clone(), + }, + policy: sandbox_policy, + sandbox, + enforce_managed_network: self.network.is_some(), + network: self.network.as_ref(), + sandbox_policy_cwd: &self.sandbox_policy_cwd, + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions, + codex_linux_sandbox_exe: self.codex_linux_sandbox_exe.as_ref(), + use_linux_sandbox_bwrap: self.use_linux_sandbox_bwrap, + windows_sandbox_level: self.windows_sandbox_level, + })?; + if let Some(network) = exec_request.network.as_ref() { + network.apply_to_env(&mut exec_request.env); + } + + Ok(PreparedExec { + command: exec_request.command, + cwd: exec_request.cwd, + env: exec_request.env, + arg0: 
exec_request.arg0, + }) + } } #[derive(Debug, Eq, PartialEq)] @@ -600,122 +799,5 @@ fn join_program_and_argv(program: &AbsolutePathBuf, argv: &[String]) -> Vec reason, - _ => "".to_string(), - }, - "unexpected shell command format for zsh-fork execution" - ); - } - - #[test] - fn join_program_and_argv_replaces_original_argv_zero() { - assert_eq!( - join_program_and_argv( - &AbsolutePathBuf::from_absolute_path("/tmp/tool").unwrap(), - &["./tool".into(), "--flag".into(), "value".into()], - ), - vec!["/tmp/tool", "--flag", "value"] - ); - assert_eq!( - join_program_and_argv( - &AbsolutePathBuf::from_absolute_path("/tmp/tool").unwrap(), - &["./tool".into()] - ), - vec!["/tmp/tool"] - ); - } - - #[test] - fn map_exec_result_preserves_stdout_and_stderr() { - let out = map_exec_result( - SandboxType::None, - ExecResult { - exit_code: 0, - stdout: "out".to_string(), - stderr: "err".to_string(), - output: "outerr".to_string(), - duration: Duration::from_millis(1), - timed_out: false, - }, - ) - .unwrap(); - - assert_eq!(out.stdout.text, "out"); - assert_eq!(out.stderr.text, "err"); - assert_eq!(out.aggregated_output.text, "outerr"); - } -} +#[path = "unix_escalation_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/runtimes/shell/unix_escalation_tests.rs b/codex-rs/core/src/tools/runtimes/shell/unix_escalation_tests.rs new file mode 100644 index 00000000000..d2fd6126708 --- /dev/null +++ b/codex-rs/core/src/tools/runtimes/shell/unix_escalation_tests.rs @@ -0,0 +1,320 @@ +use super::CoreShellActionProvider; +#[cfg(target_os = "macos")] +use super::CoreShellCommandExecutor; +use super::ParsedShellCommand; +use super::extract_shell_script; +use super::join_program_and_argv; +use super::map_exec_result; +#[cfg(target_os = "macos")] +use crate::config::Constrained; +#[cfg(target_os = "macos")] +use crate::config::Permissions; +#[cfg(target_os = "macos")] +use crate::config::types::ShellEnvironmentPolicy; +use crate::exec::SandboxType; +#[cfg(target_os = "macos")] +use 
crate::protocol::AskForApproval; +use crate::protocol::ReadOnlyAccess; +use crate::protocol::SandboxPolicy; +#[cfg(target_os = "macos")] +use crate::sandboxing::SandboxPermissions; +#[cfg(target_os = "macos")] +use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE; +#[cfg(target_os = "macos")] +use codex_protocol::config_types::WindowsSandboxLevel; +use codex_protocol::models::FileSystemPermissions; +use codex_protocol::models::MacOsPreferencesPermission; +use codex_protocol::models::MacOsSeatbeltProfileExtensions; +use codex_protocol::models::PermissionProfile; +use codex_shell_escalation::EscalationExecution; +use codex_shell_escalation::EscalationPermissions; +use codex_shell_escalation::ExecResult; +use codex_shell_escalation::Permissions as EscalatedPermissions; +#[cfg(target_os = "macos")] +use codex_shell_escalation::ShellCommandExecutor; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +#[cfg(target_os = "macos")] +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Duration; + +#[test] +fn extract_shell_script_preserves_login_flag() { + assert_eq!( + extract_shell_script(&["/bin/zsh".into(), "-lc".into(), "echo hi".into()]).unwrap(), + ParsedShellCommand { + script: "echo hi".to_string(), + login: true, + } + ); + assert_eq!( + extract_shell_script(&["/bin/zsh".into(), "-c".into(), "echo hi".into()]).unwrap(), + ParsedShellCommand { + script: "echo hi".to_string(), + login: false, + } + ); +} + +#[test] +fn extract_shell_script_supports_wrapped_command_prefixes() { + assert_eq!( + extract_shell_script(&[ + "/usr/bin/env".into(), + "CODEX_EXECVE_WRAPPER=1".into(), + "/bin/zsh".into(), + "-lc".into(), + "echo hello".into() + ]) + .unwrap(), + ParsedShellCommand { + script: "echo hello".to_string(), + login: true, + } + ); + + assert_eq!( + extract_shell_script(&[ + "sandbox-exec".into(), + "-p".into(), + "sandbox_policy".into(), + "/bin/zsh".into(), + "-c".into(), + "pwd".into(), + ]) + .unwrap(), + 
ParsedShellCommand { + script: "pwd".to_string(), + login: false, + } + ); +} + +#[test] +fn extract_shell_script_rejects_unsupported_shell_invocation() { + let err = extract_shell_script(&[ + "sandbox-exec".into(), + "-fc".into(), + "echo not supported".into(), + ]) + .unwrap_err(); + assert!(matches!(err, super::ToolError::Rejected(_))); + assert_eq!( + match err { + super::ToolError::Rejected(reason) => reason, + _ => "".to_string(), + }, + "unexpected shell command format for zsh-fork execution" + ); +} + +#[test] +fn join_program_and_argv_replaces_original_argv_zero() { + assert_eq!( + join_program_and_argv( + &AbsolutePathBuf::from_absolute_path("/tmp/tool").unwrap(), + &["./tool".into(), "--flag".into(), "value".into()], + ), + vec!["/tmp/tool", "--flag", "value"] + ); + assert_eq!( + join_program_and_argv( + &AbsolutePathBuf::from_absolute_path("/tmp/tool").unwrap(), + &["./tool".into()] + ), + vec!["/tmp/tool"] + ); +} + +#[test] +fn map_exec_result_preserves_stdout_and_stderr() { + let out = map_exec_result( + SandboxType::None, + ExecResult { + exit_code: 0, + stdout: "out".to_string(), + stderr: "err".to_string(), + output: "outerr".to_string(), + duration: Duration::from_millis(1), + timed_out: false, + }, + ) + .unwrap(); + + assert_eq!(out.stdout.text, "out"); + assert_eq!(out.stderr.text, "err"); + assert_eq!(out.aggregated_output.text, "outerr"); +} + +#[test] +fn shell_request_escalation_execution_is_explicit() { + let requested_permissions = PermissionProfile { + file_system: Some(FileSystemPermissions { + read: None, + write: Some(vec![PathBuf::from("./output")]), + }), + ..Default::default() + }; + let sandbox_policy = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![AbsolutePathBuf::from_absolute_path("/tmp/original/output").unwrap()], + read_only_access: ReadOnlyAccess::FullAccess, + network_access: false, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }; + let macos_seatbelt_profile_extensions = 
MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + ..Default::default() + }; + + assert_eq!( + CoreShellActionProvider::shell_request_escalation_execution( + crate::sandboxing::SandboxPermissions::UseDefault, + &sandbox_policy, + None, + Some(&macos_seatbelt_profile_extensions), + ), + EscalationExecution::TurnDefault, + ); + assert_eq!( + CoreShellActionProvider::shell_request_escalation_execution( + crate::sandboxing::SandboxPermissions::RequireEscalated, + &sandbox_policy, + None, + Some(&macos_seatbelt_profile_extensions), + ), + EscalationExecution::Unsandboxed, + ); + assert_eq!( + CoreShellActionProvider::shell_request_escalation_execution( + crate::sandboxing::SandboxPermissions::WithAdditionalPermissions, + &sandbox_policy, + Some(&requested_permissions), + Some(&macos_seatbelt_profile_extensions), + ), + EscalationExecution::Permissions(EscalationPermissions::Permissions( + EscalatedPermissions { + sandbox_policy, + macos_seatbelt_profile_extensions: Some(macos_seatbelt_profile_extensions), + }, + )), + ); +} + +#[cfg(target_os = "macos")] +#[tokio::test] +async fn prepare_escalated_exec_turn_default_preserves_macos_seatbelt_extensions() { + let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap(); + let executor = CoreShellCommandExecutor { + command: vec!["echo".to_string(), "ok".to_string()], + cwd: cwd.to_path_buf(), + env: HashMap::new(), + network: None, + sandbox: SandboxType::None, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + windows_sandbox_level: WindowsSandboxLevel::Disabled, + sandbox_permissions: SandboxPermissions::UseDefault, + justification: None, + arg0: None, + sandbox_policy_cwd: cwd.to_path_buf(), + macos_seatbelt_profile_extensions: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + ..Default::default() + }), + codex_linux_sandbox_exe: None, + use_linux_sandbox_bwrap: false, + }; + + let prepared = executor + 
.prepare_escalated_exec( + &AbsolutePathBuf::from_absolute_path("/bin/echo").unwrap(), + &["echo".to_string(), "ok".to_string()], + &cwd, + HashMap::new(), + EscalationExecution::TurnDefault, + ) + .await + .unwrap(); + + assert_eq!( + prepared.command.first().map(String::as_str), + Some(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + ); + assert_eq!(prepared.command.get(1).map(String::as_str), Some("-p")); + assert!( + prepared + .command + .get(2) + .is_some_and(|policy| policy.contains("(allow user-preference-write)")), + "expected seatbelt policy to include macOS extension profile: {:?}", + prepared.command + ); +} + +#[cfg(target_os = "macos")] +#[tokio::test] +async fn prepare_escalated_exec_permissions_preserve_macos_seatbelt_extensions() { + let cwd = AbsolutePathBuf::from_absolute_path(std::env::temp_dir()).unwrap(); + let executor = CoreShellCommandExecutor { + command: vec!["echo".to_string(), "ok".to_string()], + cwd: cwd.to_path_buf(), + env: HashMap::new(), + network: None, + sandbox: SandboxType::None, + sandbox_policy: SandboxPolicy::DangerFullAccess, + windows_sandbox_level: WindowsSandboxLevel::Disabled, + sandbox_permissions: SandboxPermissions::UseDefault, + justification: None, + arg0: None, + sandbox_policy_cwd: cwd.to_path_buf(), + macos_seatbelt_profile_extensions: None, + codex_linux_sandbox_exe: None, + use_linux_sandbox_bwrap: false, + }; + + let permissions = Permissions { + approval_policy: Constrained::allow_any(AskForApproval::Never), + sandbox_policy: Constrained::allow_any(SandboxPolicy::new_read_only_policy()), + network: None, + allow_login_shell: true, + shell_environment_policy: ShellEnvironmentPolicy::default(), + windows_sandbox_mode: None, + macos_seatbelt_profile_extensions: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + ..Default::default() + }), + }; + + let prepared = executor + .prepare_escalated_exec( + &AbsolutePathBuf::from_absolute_path("/bin/echo").unwrap(), + 
&["echo".to_string(), "ok".to_string()], + &cwd, + HashMap::new(), + EscalationExecution::Permissions(EscalationPermissions::Permissions( + EscalatedPermissions { + sandbox_policy: permissions.sandbox_policy.get().clone(), + macos_seatbelt_profile_extensions: permissions + .macos_seatbelt_profile_extensions + .clone(), + }, + )), + ) + .await + .unwrap(); + + assert_eq!( + prepared.command.first().map(String::as_str), + Some(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + ); + assert_eq!(prepared.command.get(1).map(String::as_str), Some("-p")); + assert!( + prepared + .command + .get(2) + .is_some_and(|policy| policy.contains("(allow user-preference-write)")), + "expected seatbelt policy to include macOS extension profile: {:?}", + prepared.command + ); +} diff --git a/codex-rs/core/src/tools/sandboxing.rs b/codex-rs/core/src/tools/sandboxing.rs index 25ea1015921..28d87b5bf3c 100644 --- a/codex-rs/core/src/tools/sandboxing.rs +++ b/codex-rs/core/src/tools/sandboxing.rs @@ -340,6 +340,8 @@ impl<'a> SandboxAttempt<'a> { enforce_managed_network: self.enforce_managed_network, network, sandbox_policy_cwd: self.sandbox_cwd, + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions: None, codex_linux_sandbox_exe: self.codex_linux_sandbox_exe, use_linux_sandbox_bwrap: self.use_linux_sandbox_bwrap, windows_sandbox_level: self.windows_sandbox_level, diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 47682cf0e79..76730d0890d 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -633,6 +633,15 @@ fn create_spawn_agent_tool(config: &ToolsConfig) -> ToolSpec { )), }, ), + ( + "fork_context".to_string(), + JsonSchema::Boolean { + description: Some( + "When true, fork the current thread history into the new agent before sending the initial prompt. This must be used when you want the new agent to have exactly the same context as you." 
+ .to_string(), + ), + }, + ), ]); ToolSpec::Function(ResponsesApiTool { diff --git a/codex-rs/core/tests/common/context_snapshot.rs b/codex-rs/core/tests/common/context_snapshot.rs index 24442dd4bd7..addb1bc31ed 100644 --- a/codex-rs/core/tests/common/context_snapshot.rs +++ b/codex-rs/core/tests/common/context_snapshot.rs @@ -238,15 +238,34 @@ fn canonicalize_snapshot_text(text: &str) -> String { return "".to_string(); } if text.starts_with("") { + let subagent_count = text + .split_once("") + .and_then(|(_, rest)| rest.split_once("")) + .map(|(subagents, _)| { + subagents + .lines() + .filter(|line| line.trim_start().starts_with("- ")) + .count() + }) + .unwrap_or(0); + let subagents_suffix = if subagent_count > 0 { + format!(":subagents={subagent_count}") + } else { + String::new() + }; if let (Some(cwd_start), Some(cwd_end)) = (text.find(""), text.find("")) { let cwd = &text[cwd_start + "".len()..cwd_end]; return if cwd.ends_with("PRETURN_CONTEXT_DIFF_CWD") { - "".to_string() + format!("") } else { - ">".to_string() + format!("{subagents_suffix}>") }; } - return "".to_string(); + return if subagent_count > 0 { + format!("") + } else { + "".to_string() + }; } if text.starts_with("You are performing a CONTEXT CHECKPOINT COMPACTION.") { return "".to_string(); @@ -308,6 +327,28 @@ mod tests { assert_eq!(rendered, "00:message/user:"); } + #[test] + fn redacted_text_mode_normalizes_environment_context_with_subagents() { + let items = vec![json!({ + "type": "message", + "role": "user", + "content": [{ + "type": "input_text", + "text": "\n /tmp/example\n bash\n \n - agent-1: atlas\n - agent-2\n \n" + }] + })]; + + let rendered = format_response_items_snapshot( + &items, + &ContextSnapshotOptions::default().render_mode(ContextSnapshotRenderMode::RedactedText), + ); + + assert_eq!( + rendered, + "00:message/user::subagents=2>" + ); + } + #[test] fn image_only_message_is_rendered_as_non_text_span() { let items = vec![json!({ diff --git 
a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index ca507557229..bb6de200a97 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -12,6 +12,7 @@ use futures::SinkExt; use futures::StreamExt; use serde_json::Value; use tokio::net::TcpListener; +use tokio::sync::Notify; use tokio::sync::oneshot; use tokio_tungstenite::accept_hdr_async_with_config; use tokio_tungstenite::tungstenite::Message; @@ -263,7 +264,7 @@ impl ResponsesRequest { .cloned() .unwrap_or(Value::Null); match output { - Value::String(text) => Some((Some(text), None)), + Value::String(_) | Value::Array(_) => Some((output_value_to_text(&output), None)), Value::Object(obj) => Some(( obj.get("content") .and_then(Value::as_str) @@ -295,6 +296,87 @@ impl ResponsesRequest { } } +pub(crate) fn output_value_to_text(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.clone()), + Value::Array(items) => match items.as_slice() { + [item] if item.get("type").and_then(Value::as_str) == Some("input_text") => { + item.get("text").and_then(Value::as_str).map(str::to_string) + } + [_] | [] | [_, _, ..] 
=> None, + }, + Value::Object(_) | Value::Number(_) | Value::Bool(_) | Value::Null => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use wiremock::http::HeaderMap; + use wiremock::http::Method; + + fn request_with_input(input: Value) -> ResponsesRequest { + ResponsesRequest(wiremock::Request { + url: "http://localhost/v1/responses" + .parse() + .expect("valid request url"), + method: Method::POST, + headers: HeaderMap::new(), + body: serde_json::to_vec(&serde_json::json!({ "input": input })) + .expect("serialize request body"), + }) + } + + #[test] + fn call_output_content_and_success_returns_only_single_text_content_item() { + let single_text = request_with_input(serde_json::json!([ + { + "type": "function_call_output", + "call_id": "call-1", + "output": [{ "type": "input_text", "text": "hello" }] + }, + { + "type": "custom_tool_call_output", + "call_id": "call-2", + "output": [{ "type": "input_text", "text": "world" }] + } + ])); + assert_eq!( + single_text.function_call_output_content_and_success("call-1"), + Some((Some("hello".to_string()), None)) + ); + assert_eq!( + single_text.custom_tool_call_output_content_and_success("call-2"), + Some((Some("world".to_string()), None)) + ); + + let mixed_content = request_with_input(serde_json::json!([ + { + "type": "function_call_output", + "call_id": "call-3", + "output": [ + { "type": "input_text", "text": "hello" }, + { "type": "input_image", "image_url": "data:image/png;base64,abc" } + ] + }, + { + "type": "custom_tool_call_output", + "call_id": "call-4", + "output": [{ "type": "input_image", "image_url": "data:image/png;base64,abc" }] + } + ])); + assert_eq!( + mixed_content.function_call_output_content_and_success("call-3"), + Some((None, None)) + ); + assert_eq!( + mixed_content.custom_tool_call_output_content_and_success("call-4"), + Some((None, None)) + ); + } +} + #[derive(Debug, Clone)] pub struct WebSocketRequest { body: Value, @@ -335,6 +417,7 @@ pub struct 
WebSocketTestServer { uri: String, connections: Arc>>>, handshakes: Arc>>, + request_log_updated: Arc, shutdown: oneshot::Sender<()>, task: tokio::task::JoinHandle<()>, } @@ -356,6 +439,26 @@ impl WebSocketTestServer { connections.first().cloned().unwrap_or_default() } + pub async fn wait_for_request( + &self, + connection_index: usize, + request_index: usize, + ) -> WebSocketRequest { + loop { + if let Some(request) = self + .connections + .lock() + .unwrap() + .get(connection_index) + .and_then(|connection| connection.get(request_index)) + .cloned() + { + return request; + } + self.request_log_updated.notified().await; + } + } + pub fn handshakes(&self) -> Vec { self.handshakes.lock().unwrap().clone() } @@ -1069,6 +1172,7 @@ pub async fn start_websocket_server(connections: Vec>>) -> WebSoc pub async fn start_websocket_server_with_headers( connections: Vec, ) -> WebSocketTestServer { + let start = std::time::Instant::now(); let listener = TcpListener::bind("127.0.0.1:0") .await .expect("bind websocket server"); @@ -1076,8 +1180,10 @@ pub async fn start_websocket_server_with_headers( let uri = format!("ws://{addr}"); let connections_log = Arc::new(Mutex::new(Vec::new())); let handshakes_log = Arc::new(Mutex::new(Vec::new())); + let request_log_updated = Arc::new(Notify::new()); let requests = Arc::clone(&connections_log); let handshakes = Arc::clone(&handshakes_log); + let request_log = Arc::clone(&request_log_updated); let connections = Arc::new(Mutex::new(VecDeque::from(connections))); let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); @@ -1159,9 +1265,51 @@ pub async fn start_websocket_server_with_headers( let mut log = requests.lock().unwrap(); if let Some(connection_log) = log.get_mut(connection_index) { connection_log.push(WebSocketRequest { body }); + let request_index = connection_log.len() - 1; + let request = &connection_log[request_index]; + let request_body = request.body_json(); + eprintln!( + "[ws test server +{}ms] connection={} received 
request={} type={:?} role={:?} text={:?} data={:?}", + start.elapsed().as_millis(), + connection_index, + request_index, + request_body.get("type").and_then(Value::as_str), + request_body + .get("item") + .and_then(|item| item.get("role")) + .and_then(Value::as_str), + request_body + .get("item") + .and_then(|item| item.get("content")) + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|content| content.get("text")) + .and_then(Value::as_str), + request_body + .get("item") + .and_then(|item| item.get("content")) + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|content| content.get("data")) + .and_then(Value::as_str), + ); } + request_log.notify_waiters(); } + eprintln!( + "[ws test server +{}ms] connection={} sending batch_size={} event_types={:?} audio_data={:?}", + start.elapsed().as_millis(), + connection_index, + request_events.len(), + request_events + .iter() + .map(|event| event.get("type").and_then(Value::as_str)) + .collect::>(), + request_events + .iter() + .find_map(|event| event.get("delta").and_then(Value::as_str)), + ); for event in &request_events { let Ok(payload) = serde_json::to_string(event) else { continue; @@ -1184,6 +1332,7 @@ pub async fn start_websocket_server_with_headers( uri, connections: connections_log, handshakes: handshakes_log, + request_log_updated, shutdown: shutdown_tx, task, } diff --git a/codex-rs/core/tests/common/test_codex.rs b/codex-rs/core/tests/common/test_codex.rs index 173db5ca776..bd15c6d7e41 100644 --- a/codex-rs/core/tests/common/test_codex.rs +++ b/codex-rs/core/tests/common/test_codex.rs @@ -11,7 +11,6 @@ use codex_core::ThreadManager; use codex_core::built_in_model_providers; use codex_core::config::Config; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -24,6 +23,7 @@ use wiremock::MockServer; 
use crate::load_default_config_for_test; use crate::responses::WebSocketTestServer; +use crate::responses::output_value_to_text; use crate::responses::start_mock_server; use crate::streaming_sse::StreamingSseServer; use crate::wait_for_event; @@ -300,7 +300,7 @@ impl TestCodex { sandbox_policy, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -395,11 +395,7 @@ impl TestCodexHarness { pub async fn custom_tool_call_output(&self, call_id: &str) -> String { let bodies = self.request_bodies().await; - custom_tool_call_output(&bodies, call_id) - .get("output") - .and_then(Value::as_str) - .expect("output string") - .to_string() + custom_tool_call_output_text(&bodies, call_id) } pub async fn apply_patch_output( @@ -434,6 +430,14 @@ fn custom_tool_call_output<'a>(bodies: &'a [Value], call_id: &str) -> &'a Value panic!("custom_tool_call_output {call_id} not found"); } +fn custom_tool_call_output_text(bodies: &[Value], call_id: &str) -> String { + let output = custom_tool_call_output(bodies, call_id) + .get("output") + .unwrap_or_else(|| panic!("custom_tool_call_output {call_id} missing output")); + output_value_to_text(output) + .unwrap_or_else(|| panic!("custom_tool_call_output {call_id} missing text output")) +} + fn function_call_output<'a>(bodies: &'a [Value], call_id: &str) -> &'a Value { for body in bodies { if let Some(items) = body.get("input").and_then(Value::as_array) { @@ -457,3 +461,36 @@ pub fn test_codex() -> TestCodexBuilder { home: None, } } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use serde_json::json; + + #[test] + fn custom_tool_call_output_text_returns_output_text() { + let bodies = vec![json!({ + "input": [{ + "type": "custom_tool_call_output", + "call_id": "call-1", + "output": "hello" + }] + })]; + + assert_eq!(custom_tool_call_output_text(&bodies, "call-1"), "hello"); + } + + #[test] + #[should_panic(expected = 
"custom_tool_call_output call-2 missing output")] + fn custom_tool_call_output_text_panics_when_output_is_missing() { + let bodies = vec![json!({ + "input": [{ + "type": "custom_tool_call_output", + "call_id": "call-2" + }] + })]; + + let _ = custom_tool_call_output_text(&bodies, "call-2"); + } +} diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index 126dc2c288d..f502ec08be9 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -111,7 +111,14 @@ async fn responses_stream_includes_subagent_header_on_review() { }]; let mut stream = client_session - .stream(&prompt, &model_info, &otel_manager, effort, summary, None) + .stream( + &prompt, + &model_info, + &otel_manager, + effort, + summary.unwrap_or(model_info.default_reasoning_summary), + None, + ) .await .expect("stream failed"); while let Some(event) = stream.next().await { @@ -216,7 +223,14 @@ async fn responses_stream_includes_subagent_header_on_other() { }]; let mut stream = client_session - .stream(&prompt, &model_info, &otel_manager, effort, summary, None) + .stream( + &prompt, + &model_info, + &otel_manager, + effort, + summary.unwrap_or(model_info.default_reasoning_summary), + None, + ) .await .expect("stream failed"); while let Some(event) = stream.next().await { @@ -267,7 +281,7 @@ async fn responses_respects_model_info_overrides_from_config() { config.model_provider_id = provider.name.clone(); config.model_provider = provider.clone(); config.model_supports_reasoning_summaries = Some(true); - config.model_reasoning_summary = ReasoningSummary::Detailed; + config.model_reasoning_summary = Some(ReasoningSummary::Detailed); let effort = config.model_reasoning_effort; let summary = config.model_reasoning_summary; let model = config.model.clone().expect("model configured"); @@ -320,7 +334,14 @@ async fn responses_respects_model_info_overrides_from_config() { }]; let mut stream = client_session - .stream(&prompt, 
&model_info, &otel_manager, effort, summary, None) + .stream( + &prompt, + &model_info, + &otel_manager, + effort, + summary.unwrap_or(model_info.default_reasoning_summary), + None, + ) .await .expect("stream failed"); while let Some(event) = stream.next().await { diff --git a/codex-rs/core/tests/suite/apply_patch_cli.rs b/codex-rs/core/tests/suite/apply_patch_cli.rs index 45494124e57..7f20f468d9f 100644 --- a/codex-rs/core/tests/suite/apply_patch_cli.rs +++ b/codex-rs/core/tests/suite/apply_patch_cli.rs @@ -12,7 +12,6 @@ use std::sync::atomic::AtomicI32; use std::sync::atomic::Ordering; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -312,7 +311,7 @@ async fn apply_patch_cli_move_without_content_change_has_no_turn_diff( sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -907,7 +906,7 @@ async fn apply_patch_shell_command_heredoc_with_cd_emits_turn_diff() -> Result<( sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -987,7 +986,7 @@ async fn apply_patch_shell_command_failure_propagates_error_and_skips_diff() -> sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1137,7 +1136,7 @@ async fn apply_patch_emits_turn_diff_event_with_unified_diff( sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1200,7 +1199,7 @@ async fn apply_patch_turn_diff_for_rename_with_content_change( sandbox_policy: SandboxPolicy::DangerFullAccess, model, 
effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1271,7 +1270,7 @@ async fn apply_patch_aggregates_diff_across_multiple_tool_calls() -> Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1342,7 +1341,7 @@ async fn apply_patch_aggregates_diff_preserves_success_after_failure() -> Result sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/approvals.rs b/codex-rs/core/tests/suite/approvals.rs index 34934b3d4b8..a0530809db2 100644 --- a/codex-rs/core/tests/suite/approvals.rs +++ b/codex-rs/core/tests/suite/approvals.rs @@ -13,7 +13,6 @@ use codex_core::sandboxing::SandboxPermissions; use codex_protocol::approvals::NetworkApprovalProtocol; use codex_protocol::approvals::NetworkPolicyAmendment; use codex_protocol::approvals::NetworkPolicyRuleAction; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::ApplyPatchApprovalRequestEvent; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; @@ -551,7 +550,7 @@ async fn submit_turn( sandbox_policy, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 4f67ea6fa7e..1de0522cb55 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -31,9 +31,14 @@ use codex_protocol::models::ReasoningItemContent; use codex_protocol::models::ReasoningItemReasoningSummary; use codex_protocol::models::ResponseItem; use codex_protocol::models::WebSearchAction; +use codex_protocol::openai_models::ModelsResponse; use 
codex_protocol::openai_models::ReasoningEffort; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; +use codex_protocol::protocol::RolloutItem; +use codex_protocol::protocol::RolloutLine; +use codex_protocol::protocol::SessionMeta; +use codex_protocol::protocol::SessionMetaLine; use codex_protocol::protocol::SessionSource; use codex_protocol::user_input::UserInput; use core_test_support::apps_test_server::AppsTestServer; @@ -343,6 +348,144 @@ async fn resume_includes_initial_messages_and_sends_prior_items() { assert!(pos_environment < pos_new_user); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn resume_replays_legacy_js_repl_image_rollout_shapes() { + skip_if_no_network!(); + + // Early js_repl builds persisted image tool results as two separate rollout items: + // a string-valued custom_tool_call_output plus a standalone user input_image message. + // Current image tests cover today's shapes; this keeps resume compatibility for that + // legacy rollout representation. 
+ let legacy_custom_tool_call = ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "legacy-js-call".to_string(), + name: "js_repl".to_string(), + input: "console.log('legacy image flow')".to_string(), + }; + let legacy_image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg=="; + let rollout = vec![ + RolloutLine { + timestamp: "2024-01-01T00:00:00.000Z".to_string(), + item: RolloutItem::SessionMeta(SessionMetaLine { + meta: SessionMeta { + id: ThreadId::default(), + timestamp: "2024-01-01T00:00:00Z".to_string(), + cwd: ".".into(), + originator: "test_originator".to_string(), + cli_version: "test_version".to_string(), + model_provider: Some("test-provider".to_string()), + ..Default::default() + }, + git: None, + }), + }, + RolloutLine { + timestamp: "2024-01-01T00:00:01.000Z".to_string(), + item: RolloutItem::ResponseItem(legacy_custom_tool_call), + }, + RolloutLine { + timestamp: "2024-01-01T00:00:02.000Z".to_string(), + item: RolloutItem::ResponseItem(ResponseItem::CustomToolCallOutput { + call_id: "legacy-js-call".to_string(), + output: FunctionCallOutputPayload::from_text("legacy js_repl stdout".to_string()), + }), + }, + RolloutLine { + timestamp: "2024-01-01T00:00:03.000Z".to_string(), + item: RolloutItem::ResponseItem(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputImage { + image_url: legacy_image_url.to_string(), + }], + end_turn: None, + phase: None, + }), + }, + ]; + + let tmpdir = TempDir::new().unwrap(); + let session_path = tmpdir + .path() + .join("resume-legacy-js-repl-image-rollout.jsonl"); + let mut f = std::fs::File::create(&session_path).unwrap(); + for line in rollout { + writeln!(f, "{}", serde_json::to_string(&line).unwrap()).unwrap(); + } + + let server = MockServer::start().await; + let resp_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp1"), ev_completed("resp1")]), + ) + 
.await; + + let codex_home = Arc::new(TempDir::new().unwrap()); + let mut builder = test_codex().with_model("gpt-5.1"); + let test = builder + .resume(&server, codex_home, session_path.clone()) + .await + .expect("resume conversation"); + test.submit_turn("after resume").await.unwrap(); + + let input = resp_mock.single_request().input(); + + let legacy_output_index = input + .iter() + .position(|item| { + item.get("type").and_then(|value| value.as_str()) == Some("custom_tool_call_output") + && item.get("call_id").and_then(|value| value.as_str()) == Some("legacy-js-call") + }) + .expect("legacy custom tool output should be replayed"); + assert_eq!( + input[legacy_output_index] + .get("output") + .and_then(|value| value.as_str()), + Some("legacy js_repl stdout") + ); + + let legacy_image_index = input + .iter() + .position(|item| { + item.get("type").and_then(|value| value.as_str()) == Some("message") + && item.get("role").and_then(|value| value.as_str()) == Some("user") + && item + .get("content") + .and_then(|value| value.as_array()) + .is_some_and(|content| { + content.iter().any(|entry| { + entry.get("type").and_then(|value| value.as_str()) + == Some("input_image") + && entry.get("image_url").and_then(|value| value.as_str()) + == Some(legacy_image_url) + }) + }) + }) + .expect("legacy injected image message should be replayed"); + + let new_user_index = input + .iter() + .position(|item| { + item.get("type").and_then(|value| value.as_str()) == Some("message") + && item.get("role").and_then(|value| value.as_str()) == Some("user") + && item + .get("content") + .and_then(|value| value.as_array()) + .is_some_and(|content| { + content.iter().any(|entry| { + entry.get("type").and_then(|value| value.as_str()) == Some("input_text") + && entry.get("text").and_then(|value| value.as_str()) + == Some("after resume") + }) + }) + }) + .expect("new user message should be present"); + + assert!(legacy_output_index < new_user_index); + assert!(legacy_image_index < 
new_user_index); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn includes_conversation_id_and_model_headers_in_request() { skip_if_no_network!(); @@ -980,7 +1123,11 @@ async fn user_turn_collaboration_mode_overrides_model_and_effort() -> anyhow::Re sandbox_policy: config.permissions.sandbox_policy.get().clone(), model: session_configured.model.clone(), effort: Some(ReasoningEffort::Low), - summary: config.model_reasoning_summary, + summary: Some( + config + .model_reasoning_summary + .unwrap_or(ReasoningSummary::Auto), + ), collaboration_mode: Some(collaboration_mode), final_output_json_schema: None, personality: None, @@ -1014,7 +1161,7 @@ async fn configured_reasoning_summary_is_sent() -> anyhow::Result<()> { .await; let TestCodex { codex, .. } = test_codex() .with_config(|config| { - config.model_reasoning_summary = ReasoningSummary::Concise; + config.model_reasoning_summary = Some(ReasoningSummary::Concise); }) .build(&server) .await?; @@ -1046,6 +1193,75 @@ async fn configured_reasoning_summary_is_sent() -> anyhow::Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn user_turn_explicit_reasoning_summary_overrides_model_catalog_default() -> anyhow::Result<()> +{ + skip_if_no_network!(Ok(())); + let server = MockServer::start().await; + + let resp_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp1"), ev_completed("resp1")]), + ) + .await; + + let mut model_catalog: ModelsResponse = + serde_json::from_str(include_str!("../../models.json")).expect("valid models.json"); + let model = model_catalog + .models + .iter_mut() + .find(|model| model.slug == "gpt-5.1") + .expect("gpt-5.1 exists in bundled models.json"); + model.supports_reasoning_summaries = true; + model.default_reasoning_summary = ReasoningSummary::Detailed; + + let TestCodex { + codex, + config, + session_configured, + .. 
+ } = test_codex() + .with_model("gpt-5.1") + .with_config(move |config| { + config.model_catalog = Some(model_catalog); + }) + .build(&server) + .await?; + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + cwd: config.cwd.clone(), + approval_policy: config.permissions.approval_policy.value(), + sandbox_policy: config.permissions.sandbox_policy.get().clone(), + model: session_configured.model, + effort: None, + summary: Some(ReasoningSummary::Concise), + collaboration_mode: None, + final_output_json_schema: None, + personality: None, + }) + .await + .unwrap(); + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let request_body = resp_mock.single_request().body_json(); + + pretty_assertions::assert_eq!( + request_body + .get("reasoning") + .and_then(|reasoning| reasoning.get("summary")) + .and_then(|value| value.as_str()), + Some("concise") + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn reasoning_summary_is_omitted_when_disabled() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -1058,7 +1274,7 @@ async fn reasoning_summary_is_omitted_when_disabled() -> anyhow::Result<()> { .await; let TestCodex { codex, .. 
} = test_codex() .with_config(|config| { - config.model_reasoning_summary = ReasoningSummary::None; + config.model_reasoning_summary = Some(ReasoningSummary::None); }) .build(&server) .await?; @@ -1089,6 +1305,60 @@ async fn reasoning_summary_is_omitted_when_disabled() -> anyhow::Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn reasoning_summary_none_overrides_model_catalog_default() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + let server = MockServer::start().await; + + let resp_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp1"), ev_completed("resp1")]), + ) + .await; + + let mut model_catalog: ModelsResponse = + serde_json::from_str(include_str!("../../models.json")).expect("valid models.json"); + let model = model_catalog + .models + .iter_mut() + .find(|model| model.slug == "gpt-5.1") + .expect("gpt-5.1 exists in bundled models.json"); + model.supports_reasoning_summaries = true; + model.default_reasoning_summary = ReasoningSummary::Detailed; + + let TestCodex { codex, .. 
} = test_codex() + .with_model("gpt-5.1") + .with_config(move |config| { + config.model_reasoning_summary = Some(ReasoningSummary::None); + config.model_catalog = Some(model_catalog); + }) + .build(&server) + .await?; + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await + .unwrap(); + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let request_body = resp_mock.single_request().body_json(); + pretty_assertions::assert_eq!( + request_body + .get("reasoning") + .and_then(|reasoning| reasoning.get("summary")), + None + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn includes_default_verbosity_in_request() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -1437,11 +1707,18 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { }); prompt.input.push(ResponseItem::CustomToolCallOutput { call_id: "custom-tool-call-id".into(), - output: "ok".into(), + output: FunctionCallOutputPayload::from_text("ok".into()), }); let mut stream = client_session - .stream(&prompt, &model_info, &otel_manager, effort, summary, None) + .stream( + &prompt, + &model_info, + &otel_manager, + effort, + summary.unwrap_or(ReasoningSummary::Auto), + None, + ) .await .expect("responses stream to start"); diff --git a/codex-rs/core/tests/suite/collaboration_instructions.rs b/codex-rs/core/tests/suite/collaboration_instructions.rs index 68d695325d3..b1dbd2d6600 100644 --- a/codex-rs/core/tests/suite/collaboration_instructions.rs +++ b/codex-rs/core/tests/suite/collaboration_instructions.rs @@ -169,7 +169,11 @@ async fn collaboration_instructions_added_on_user_turn() -> Result<()> { sandbox_policy: test.config.permissions.sandbox_policy.get().clone(), model: test.session_configured.model.clone(), effort: None, - summary: test.config.model_reasoning_summary, + summary: Some( + test.config + 
.model_reasoning_summary + .unwrap_or(codex_protocol::config_types::ReasoningSummary::Auto), + ), collaboration_mode: Some(collaboration_mode), final_output_json_schema: None, personality: None, @@ -275,7 +279,11 @@ async fn user_turn_overrides_collaboration_instructions_after_override() -> Resu sandbox_policy: test.config.permissions.sandbox_policy.get().clone(), model: test.session_configured.model.clone(), effort: None, - summary: test.config.model_reasoning_summary, + summary: Some( + test.config + .model_reasoning_summary + .unwrap_or(codex_protocol::config_types::ReasoningSummary::Auto), + ), collaboration_mode: Some(turn_mode), final_output_json_schema: None, personality: None, diff --git a/codex-rs/core/tests/suite/compact.rs b/codex-rs/core/tests/suite/compact.rs index 58a7095aa14..9e376ce494e 100644 --- a/codex-rs/core/tests/suite/compact.rs +++ b/codex-rs/core/tests/suite/compact.rs @@ -5,7 +5,6 @@ use codex_core::built_in_model_providers; use codex_core::compact::SUMMARIZATION_PROMPT; use codex_core::compact::SUMMARY_PREFIX; use codex_core::config::Config; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::items::TurnItem; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelsResponse; @@ -1659,7 +1658,7 @@ async fn auto_compact_runs_after_resume_when_token_usage_is_over_limit() { sandbox_policy: SandboxPolicy::DangerFullAccess, model: resumed.session_configured.model.clone(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1748,7 +1747,7 @@ async fn pre_sampling_compact_runs_on_switch_to_smaller_context_model() { sandbox_policy: SandboxPolicy::DangerFullAccess, model: previous_model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1771,7 +1770,7 @@ async fn pre_sampling_compact_runs_on_switch_to_smaller_context_model() { sandbox_policy: 
SandboxPolicy::DangerFullAccess, model: next_model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1880,7 +1879,7 @@ async fn pre_sampling_compact_runs_after_resume_and_switch_to_smaller_model() { sandbox_policy: SandboxPolicy::DangerFullAccess, model: previous_model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1927,7 +1926,7 @@ async fn pre_sampling_compact_runs_after_resume_and_switch_to_smaller_model() { sandbox_policy: SandboxPolicy::DangerFullAccess, model: next_model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -3128,7 +3127,7 @@ async fn snapshot_request_shape_pre_turn_compaction_strips_incoming_model_switch sandbox_policy: SandboxPolicy::DangerFullAccess, model: previous_model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -3151,7 +3150,7 @@ async fn snapshot_request_shape_pre_turn_compaction_strips_incoming_model_switch sandbox_policy: SandboxPolicy::DangerFullAccess, model: next_model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/exec_policy.rs b/codex-rs/core/tests/suite/exec_policy.rs index 3fd06d83c73..52717f28804 100644 --- a/codex-rs/core/tests/suite/exec_policy.rs +++ b/codex-rs/core/tests/suite/exec_policy.rs @@ -4,7 +4,6 @@ use anyhow::Result; use codex_core::features::Feature; use codex_protocol::config_types::CollaborationMode; use codex_protocol::config_types::ModeKind; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::Settings; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; @@ -55,7 +54,7 @@ async fn submit_user_turn( 
sandbox_policy, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode, personality: None, }) @@ -134,7 +133,7 @@ async fn execpolicy_blocks_shell_invocation() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/image_rollout.rs b/codex-rs/core/tests/suite/image_rollout.rs index 2349d9f5fad..c8f1f2eaf1e 100644 --- a/codex-rs/core/tests/suite/image_rollout.rs +++ b/codex-rs/core/tests/suite/image_rollout.rs @@ -1,5 +1,4 @@ use anyhow::Context; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::AskForApproval; @@ -126,7 +125,7 @@ async fn copy_paste_local_image_persists_rollout_request_shape() -> anyhow::Resu sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -208,7 +207,7 @@ async fn drag_drop_image_persists_rollout_request_shape() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/items.rs b/codex-rs/core/tests/suite/items.rs index 0d93278af13..dfb5ac88b47 100644 --- a/codex-rs/core/tests/suite/items.rs +++ b/codex-rs/core/tests/suite/items.rs @@ -377,7 +377,7 @@ async fn plan_mode_emits_plan_item_from_proposed_plan_block() -> anyhow::Result< sandbox_policy: codex_protocol::protocol::SandboxPolicy::DangerFullAccess, model: session_configured.model.clone(), effort: None, - summary: codex_protocol::config_types::ReasoningSummary::Auto, + summary: None, collaboration_mode: 
Some(collaboration_mode), personality: None, }) @@ -452,7 +452,7 @@ async fn plan_mode_strips_plan_from_agent_messages() -> anyhow::Result<()> { sandbox_policy: codex_protocol::protocol::SandboxPolicy::DangerFullAccess, model: session_configured.model.clone(), effort: None, - summary: codex_protocol::config_types::ReasoningSummary::Auto, + summary: None, collaboration_mode: Some(collaboration_mode), personality: None, }) @@ -559,7 +559,7 @@ async fn plan_mode_streaming_citations_are_stripped_across_added_deltas_and_done sandbox_policy: codex_protocol::protocol::SandboxPolicy::DangerFullAccess, model: session_configured.model.clone(), effort: None, - summary: codex_protocol::config_types::ReasoningSummary::Auto, + summary: None, collaboration_mode: Some(collaboration_mode), personality: None, }) @@ -744,7 +744,7 @@ async fn plan_mode_streaming_proposed_plan_tag_split_across_added_and_delta_is_p sandbox_policy: codex_protocol::protocol::SandboxPolicy::DangerFullAccess, model: session_configured.model.clone(), effort: None, - summary: codex_protocol::config_types::ReasoningSummary::Auto, + summary: None, collaboration_mode: Some(collaboration_mode), personality: None, }) @@ -856,7 +856,7 @@ async fn plan_mode_handles_missing_plan_close_tag() -> anyhow::Result<()> { sandbox_policy: codex_protocol::protocol::SandboxPolicy::DangerFullAccess, model: session_configured.model.clone(), effort: None, - summary: codex_protocol::config_types::ReasoningSummary::Auto, + summary: None, collaboration_mode: Some(collaboration_mode), personality: None, }) diff --git a/codex-rs/core/tests/suite/json_result.rs b/codex-rs/core/tests/suite/json_result.rs index e32865e0c7e..12c9845a756 100644 --- a/codex-rs/core/tests/suite/json_result.rs +++ b/codex-rs/core/tests/suite/json_result.rs @@ -1,6 +1,5 @@ #![cfg(not(target_os = "windows"))] -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use 
codex_protocol::protocol::Op; @@ -84,7 +83,7 @@ async fn codex_returns_json_result(model: String) -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/live_reload.rs b/codex-rs/core/tests/suite/live_reload.rs index 3ea0a3cbe95..8ee4c56ddae 100644 --- a/codex-rs/core/tests/suite/live_reload.rs +++ b/codex-rs/core/tests/suite/live_reload.rs @@ -7,7 +7,6 @@ use std::time::Duration; use anyhow::Result; use codex_core::config::ProjectConfig; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::TrustLevel; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; @@ -65,7 +64,7 @@ async fn submit_skill_turn(test: &TestCodex, skill_path: PathBuf, prompt: &str) sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/memories.rs b/codex-rs/core/tests/suite/memories.rs index fc46c9df133..8f96178e867 100644 --- a/codex-rs/core/tests/suite/memories.rs +++ b/codex-rs/core/tests/suite/memories.rs @@ -83,6 +83,7 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() - let rollout_summaries = read_rollout_summary_bodies(&memory_root).await?; assert_eq!(rollout_summaries.len(), 1); assert!(rollout_summaries[0].contains("rollout summary A")); + assert!(rollout_summaries[0].contains("git_branch: branch-rollout-a")); shutdown_test_codex(&first).await?; @@ -141,6 +142,11 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() - .iter() .any(|summary| summary.contains("rollout summary B")) ); + assert!( + rollout_summaries + .iter() + .any(|summary| summary.contains("git_branch: branch-rollout-b")) + ); assert!( rollout_summaries 
.iter() @@ -185,6 +191,7 @@ async fn seed_stage1_output( ); metadata_builder.cwd = codex_home.join(format!("workspace-{rollout_slug}")); metadata_builder.model_provider = Some("test-provider".to_string()); + metadata_builder.git_branch = Some(format!("branch-{rollout_slug}")); let metadata = metadata_builder.build("test-provider"); db.upsert_thread(&metadata).await?; @@ -248,7 +255,7 @@ async fn wait_for_phase2_success( ) -> Result<()> { let deadline = Instant::now() + Duration::from_secs(10); loop { - let selection = db.get_phase2_input_selection(1).await?; + let selection = db.get_phase2_input_selection(1, 30).await?; if selection.selected.len() == 1 && selection.selected[0].thread_id == expected_thread_id && selection.retained_thread_ids == vec![expected_thread_id] diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index fd20fc6f8bc..e23fd9ddbbf 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -115,6 +115,7 @@ mod seatbelt; mod shell_command; mod shell_serialization; mod shell_snapshot; +mod skill_approval; mod skills; mod sqlite_state; mod stream_error_allows_next_turn; diff --git a/codex-rs/core/tests/suite/model_switching.rs b/codex-rs/core/tests/suite/model_switching.rs index 6e06b0f2248..62d0545462b 100644 --- a/codex-rs/core/tests/suite/model_switching.rs +++ b/codex-rs/core/tests/suite/model_switching.rs @@ -58,7 +58,7 @@ async fn model_change_appends_model_instructions_developer_message() -> Result<( sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -91,7 +91,7 @@ async fn model_change_appends_model_instructions_developer_message() -> Result<( sandbox_policy: SandboxPolicy::new_read_only_policy(), model: next_model.to_string(), effort: test.config.model_reasoning_effort, - summary: 
ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -146,7 +146,7 @@ async fn model_and_personality_change_only_appends_model_instructions() -> Resul sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -179,7 +179,7 @@ async fn model_and_personality_change_only_appends_model_instructions() -> Resul sandbox_policy: SandboxPolicy::new_read_only_policy(), model: next_model.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -234,8 +234,10 @@ async fn model_change_from_image_to_text_strips_prior_image_content() -> Result< base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -293,7 +295,7 @@ async fn model_change_from_image_to_text_strips_prior_image_content() -> Result< sandbox_policy: SandboxPolicy::new_read_only_policy(), model: image_model_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -312,7 +314,7 @@ async fn model_change_from_image_to_text_strips_prior_image_content() -> Result< sandbox_policy: SandboxPolicy::new_read_only_policy(), model: text_model_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -391,8 +393,10 @@ async fn model_switch_to_smaller_model_updates_token_context_window() -> 
Result< base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -469,7 +473,7 @@ async fn model_switch_to_smaller_model_updates_token_context_window() -> Result< sandbox_policy: SandboxPolicy::new_read_only_policy(), model: large_model_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -524,7 +528,7 @@ async fn model_switch_to_smaller_model_updates_token_context_window() -> Result< sandbox_policy: SandboxPolicy::new_read_only_policy(), model: smaller_model_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/model_visible_layout.rs b/codex-rs/core/tests/suite/model_visible_layout.rs index 3288eb5c051..14503b3f5b6 100644 --- a/codex-rs/core/tests/suite/model_visible_layout.rs +++ b/codex-rs/core/tests/suite/model_visible_layout.rs @@ -6,7 +6,6 @@ use std::sync::Arc; use anyhow::Result; use codex_core::config::types::Personality; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -26,6 +25,7 @@ use core_test_support::responses::start_mock_server; use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use serde_json::json; const PRETURN_CONTEXT_DIFF_CWD: &str = "PRETURN_CONTEXT_DIFF_CWD"; @@ -53,6 +53,30 @@ fn agents_message_count(request: &ResponsesRequest) -> usize { 
.count() } +fn format_environment_context_subagents_snapshot(subagents: &[&str]) -> String { + let subagents_block = if subagents.is_empty() { + String::new() + } else { + let lines = subagents + .iter() + .map(|line| format!(" {line}")) + .collect::>() + .join("\n"); + format!("\n \n{lines}\n ") + }; + let items = vec![json!({ + "type": "message", + "role": "user", + "content": [{ + "type": "input_text", + "text": format!( + "\n /tmp/example\n bash{subagents_block}\n" + ), + }], + })]; + context_snapshot::format_response_items_snapshot(items.as_slice(), &context_snapshot_options()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn snapshot_model_visible_layout_turn_overrides() -> Result<()> { skip_if_no_network!(Ok(())); @@ -97,7 +121,7 @@ async fn snapshot_model_visible_layout_turn_overrides() -> Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -119,7 +143,7 @@ async fn snapshot_model_visible_layout_turn_overrides() -> Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: Some(Personality::Friendly), }) @@ -196,7 +220,7 @@ async fn snapshot_model_visible_layout_cwd_change_does_not_refresh_agents() -> R sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -218,7 +242,7 @@ async fn snapshot_model_visible_layout_cwd_change_does_not_refresh_agents() -> R sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), 
effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -323,7 +347,7 @@ async fn snapshot_model_visible_layout_resume_with_personality_change() -> Resul sandbox_policy: SandboxPolicy::new_read_only_policy(), model: resumed.session_configured.model.clone(), effort: resumed.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: Some(Personality::Friendly), }) @@ -445,3 +469,23 @@ async fn snapshot_model_visible_layout_resume_override_matches_rollout_model() - Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn snapshot_model_visible_layout_environment_context_includes_one_subagent() -> Result<()> { + insta::assert_snapshot!( + "model_visible_layout_environment_context_includes_one_subagent", + format_environment_context_subagents_snapshot(&["- agent-1: Atlas"]) + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn snapshot_model_visible_layout_environment_context_includes_two_subagents() -> Result<()> { + insta::assert_snapshot!( + "model_visible_layout_environment_context_includes_two_subagents", + format_environment_context_subagents_snapshot(&["- agent-1: Atlas", "- agent-2: Juniper"]) + ); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/models_cache_ttl.rs b/codex-rs/core/tests/suite/models_cache_ttl.rs index 3422e397b89..9948eb4e21e 100644 --- a/codex-rs/core/tests/suite/models_cache_ttl.rs +++ b/codex-rs/core/tests/suite/models_cache_ttl.rs @@ -99,7 +99,7 @@ async fn renews_cache_ttl_on_matching_models_etag() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: test.session_configured.model.clone(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -336,8 +336,10 @@ fn test_remote_model(slug: &str, priority: i32) -> ModelInfo { base_instructions: 
"base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, diff --git a/codex-rs/core/tests/suite/models_etag_responses.rs b/codex-rs/core/tests/suite/models_etag_responses.rs index 7c40bec987d..1cdc5490129 100644 --- a/codex-rs/core/tests/suite/models_etag_responses.rs +++ b/codex-rs/core/tests/suite/models_etag_responses.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use anyhow::Result; use codex_core::CodexAuth; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ModelsResponse; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; @@ -106,7 +105,7 @@ async fn refresh_models_on_models_etag_mismatch_and_avoid_duplicate_models_fetch sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/personality.rs b/codex-rs/core/tests/suite/personality.rs index 0b4d485cfb0..97ce2360ba8 100644 --- a/codex-rs/core/tests/suite/personality.rs +++ b/codex-rs/core/tests/suite/personality.rs @@ -97,7 +97,7 @@ async fn user_turn_personality_none_does_not_add_update_message() -> anyhow::Res sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -143,7 +143,7 @@ async fn config_personality_some_sets_instructions_template() -> anyhow::Result< sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - 
summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -196,7 +196,7 @@ async fn config_personality_none_sends_no_personality() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -255,7 +255,7 @@ async fn default_personality_is_pragmatic_without_config_toml() -> anyhow::Resul sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -302,7 +302,7 @@ async fn user_turn_personality_some_adds_update_message() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -336,7 +336,7 @@ async fn user_turn_personality_some_adds_update_message() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -398,7 +398,7 @@ async fn user_turn_personality_same_value_does_not_add_update_message() -> anyho sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -432,7 +432,7 @@ async fn user_turn_personality_same_value_does_not_add_update_message() -> anyho sandbox_policy: SandboxPolicy::new_read_only_policy(), model: 
test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -504,7 +504,7 @@ async fn user_turn_personality_skips_if_feature_disabled() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -538,7 +538,7 @@ async fn user_turn_personality_skips_if_feature_disabled() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::new_read_only_policy(), model: test.session_configured.model.clone(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -599,8 +599,10 @@ async fn remote_model_friendly_personality_instructions_with_feature() -> anyhow }), }), supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -646,7 +648,7 @@ async fn remote_model_friendly_personality_instructions_with_feature() -> anyhow sandbox_policy: SandboxPolicy::new_read_only_policy(), model: remote_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: Some(Personality::Friendly), }) @@ -706,8 +708,10 @@ async fn user_turn_personality_remote_model_template_includes_update_message() - }), }), supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), 
supports_parallel_tool_calls: false, @@ -756,7 +760,7 @@ async fn user_turn_personality_remote_model_template_includes_update_message() - sandbox_policy: SandboxPolicy::new_read_only_policy(), model: remote_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -790,7 +794,7 @@ async fn user_turn_personality_remote_model_template_includes_update_message() - sandbox_policy: SandboxPolicy::new_read_only_policy(), model: remote_slug.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs index cc10fa3754b..54ee11fba47 100644 --- a/codex-rs/core/tests/suite/prompt_caching.rs +++ b/codex-rs/core/tests/suite/prompt_caching.rs @@ -44,14 +44,32 @@ fn text_user_input_parts(texts: Vec) -> serde_json::Value { }) } -fn default_env_context_str(cwd: &str, shell: &Shell) -> String { +fn assert_default_env_context(text: &str, cwd: &str, shell: &Shell) { let shell_name = shell.name(); - format!( - r#" - {cwd} - {shell_name} -"# - ) + assert!( + text.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG), + "expected environment context fragment: {text}" + ); + assert!( + text.contains(&format!("{cwd}")), + "expected cwd in environment context: {text}" + ); + assert!( + text.contains(&format!("{shell_name}")), + "expected shell in environment context: {text}" + ); + assert!( + text.contains("") && text.contains(""), + "expected current_date in environment context: {text}" + ); + assert!( + text.contains("") && text.contains(""), + "expected timezone in environment context: {text}" + ); + assert!( + text.ends_with(""), + "expected closing environment_context tag: {text}" + ); } fn assert_tool_names(body: &serde_json::Value, expected_names: &[&str]) { @@ -318,10 +336,13 @@ async fn 
prefixes_context_and_instructions_once_and_consistently_across_requests let shell = default_user_shell(); let cwd_str = config.cwd.to_string_lossy(); - let expected_env_text = default_env_context_str(&cwd_str, &shell); + let env_text = input1[1]["content"][1]["text"] + .as_str() + .expect("environment context text"); + assert_default_env_context(env_text, &cwd_str, &shell); assert_eq!( - input1[1]["content"][1]["text"].as_str(), - Some(expected_env_text.as_str()), + input1[1]["content"][1]["type"].as_str(), + Some("input_text"), "expected environment context bundled after UI message in cached contextual message" ); assert_eq!(input1[2], text_user_input("hello 1".to_string())); @@ -523,6 +544,18 @@ async fn override_before_first_turn_emits_environment_context() -> anyhow::Resul !env_texts.is_empty(), "expected environment context to be emitted: {env_texts:?}" ); + assert!( + env_texts + .iter() + .any(|text| text.contains("") && text.contains("")), + "expected current_date in environment context: {env_texts:?}" + ); + assert!( + env_texts + .iter() + .any(|text| text.contains("") && text.contains("")), + "expected timezone in environment context: {env_texts:?}" + ); let env_count = input .iter() @@ -646,7 +679,7 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res sandbox_policy: new_policy.clone(), model: "o3".to_string(), effort: Some(ReasoningEffort::High), - summary: ReasoningSummary::Detailed, + summary: Some(ReasoningSummary::Detailed), collaboration_mode: None, final_output_json_schema: None, personality: None, @@ -672,21 +705,6 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res "role": "user", "content": [ { "type": "input_text", "text": "hello 2" } ] }); - let shell = default_user_shell(); - - let expected_env_text_2 = format!( - r#" - {} - {} -"#, - new_cwd.path().display(), - shell.name() - ); - let expected_env_msg_2 = serde_json::json!({ - "type": "message", - "role": "user", - "content": [ 
{ "type": "input_text", "text": expected_env_text_2 } ] - }); let expected_permissions_msg = body1["input"][0].clone(); let body1_input = body1["input"].as_array().expect("input array"); let expected_settings_update_msg = body2["input"][body1_input.len()].clone(); @@ -704,6 +722,14 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res }), "expected model switch section after model override: {expected_settings_update_msg:?}" ); + let expected_env_msg_2 = body2["input"][body1_input.len() + 1].clone(); + assert_eq!(expected_env_msg_2["role"].as_str(), Some("user")); + let env_text = expected_env_msg_2["content"][0]["text"] + .as_str() + .expect("environment context text"); + let shell = default_user_shell(); + let expected_cwd = new_cwd.path().display().to_string(); + assert_default_env_context(env_text, &expected_cwd, &shell); let mut expected_body2 = body1_input.to_vec(); expected_body2.push(expected_settings_update_msg); expected_body2.push(expected_env_msg_2); @@ -761,7 +787,7 @@ async fn send_user_turn_with_no_changes_does_not_send_environment_context() -> a sandbox_policy: default_sandbox_policy.clone(), model: default_model.clone(), effort: default_effort, - summary: default_summary, + summary: Some(default_summary.unwrap_or(ReasoningSummary::Auto)), collaboration_mode: None, final_output_json_schema: None, personality: None, @@ -780,7 +806,7 @@ async fn send_user_turn_with_no_changes_does_not_send_environment_context() -> a sandbox_policy: default_sandbox_policy.clone(), model: default_model.clone(), effort: default_effort, - summary: default_summary, + summary: Some(default_summary.unwrap_or(ReasoningSummary::Auto)), collaboration_mode: None, final_output_json_schema: None, personality: None, @@ -798,13 +824,18 @@ async fn send_user_turn_with_no_changes_does_not_send_environment_context() -> a let shell = default_user_shell(); let default_cwd_lossy = default_cwd.to_string_lossy(); + let expected_env_text_1 = 
expected_ui_msg["content"][1]["text"] + .as_str() + .expect("cached environment context text") + .to_string(); + assert_default_env_context(&expected_env_text_1, &default_cwd_lossy, &shell); let expected_contextual_user_msg_1 = text_user_input_parts(vec![ expected_ui_msg["content"][0]["text"] .as_str() .expect("cached user instructions text") .to_string(), - default_env_context_str(&default_cwd_lossy, &shell), + expected_env_text_1, ]); let expected_user_message_1 = text_user_input("hello 1".to_string()); @@ -875,7 +906,7 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu sandbox_policy: default_sandbox_policy.clone(), model: default_model, effort: default_effort, - summary: default_summary, + summary: Some(default_summary.unwrap_or(ReasoningSummary::Auto)), collaboration_mode: None, final_output_json_schema: None, personality: None, @@ -894,7 +925,7 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu sandbox_policy: SandboxPolicy::DangerFullAccess, model: "o3".to_string(), effort: Some(ReasoningEffort::High), - summary: ReasoningSummary::Detailed, + summary: Some(ReasoningSummary::Detailed), collaboration_mode: None, final_output_json_schema: None, personality: None, @@ -911,7 +942,11 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu let expected_ui_msg = body1["input"][1].clone(); let shell = default_user_shell(); - let expected_env_text_1 = default_env_context_str(&default_cwd.to_string_lossy(), &shell); + let expected_env_text_1 = expected_ui_msg["content"][1]["text"] + .as_str() + .expect("cached environment context text") + .to_string(); + assert_default_env_context(&expected_env_text_1, &default_cwd.to_string_lossy(), &shell); let expected_contextual_user_msg_1 = text_user_input_parts(vec![ expected_ui_msg["content"][0]["text"] .as_str() diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index 
98848e911cd..1f7504d41b7 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -715,6 +715,7 @@ async fn inbound_realtime_text_ignores_user_role_and_still_forwards_audio() -> R async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_audio() -> Result<()> { skip_if_no_network!(Ok(())); + let start = std::time::Instant::now(); let (gate_completed_tx, gate_completed_rx) = oneshot::channel(); let first_chunks = vec![ @@ -806,18 +807,45 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au _ => None, }) .await; + eprintln!( + "[realtime test +{}ms] saw trigger text={:?}", + start.elapsed().as_millis(), + "delegate now" + ); - let audio_out = tokio::time::timeout( - Duration::from_millis(500), - wait_for_event_match(&test.codex, |msg| match msg { - EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { - payload: RealtimeEvent::AudioOut(frame), - }) => Some(frame.clone()), - _ => None, - }), - ) - .await - .expect("timed out waiting for realtime audio after echoed user-role message"); + let mirrored_request = realtime_server.wait_for_request(0, 1).await; + let mirrored_request_body = mirrored_request.body_json(); + eprintln!( + "[realtime test +{}ms] saw mirrored request type={:?} role={:?} text={:?} data={:?}", + start.elapsed().as_millis(), + mirrored_request_body["type"].as_str(), + mirrored_request_body["item"]["role"].as_str(), + mirrored_request_body["item"]["content"][0]["text"].as_str(), + mirrored_request_body["item"]["content"][0]["data"].as_str(), + ); + assert_eq!( + mirrored_request_body["type"].as_str(), + Some("conversation.item.create") + ); + assert_eq!( + mirrored_request_body["item"]["content"][0]["text"].as_str(), + Some("assistant says hi") + ); + + let audio_out = wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: 
RealtimeEvent::AudioOut(frame), + }) => Some(frame.clone()), + _ => None, + }) + .await; + eprintln!( + "[realtime test +{}ms] saw audio out data={} sample_rate={} num_channels={}", + start.elapsed().as_millis(), + audio_out.data, + audio_out.sample_rate, + audio_out.num_channels + ); assert_eq!(audio_out.data, "AQID"); let completion = completions @@ -828,6 +856,10 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au completion .await .expect("delegated turn request did not complete"); + eprintln!( + "[realtime test +{}ms] delegated completion resolved", + start.elapsed().as_millis() + ); wait_for_event(&test.codex, |event| { matches!(event, EventMsg::TurnComplete(_)) }) diff --git a/codex-rs/core/tests/suite/remote_models.rs b/codex-rs/core/tests/suite/remote_models.rs index 166f0ceac6c..c304cc25fa0 100644 --- a/codex-rs/core/tests/suite/remote_models.rs +++ b/codex-rs/core/tests/suite/remote_models.rs @@ -139,6 +139,7 @@ async fn remote_models_long_model_slug_is_sent_with_high_reasoning() -> Result<( }, ]; remote_model.supports_reasoning_summaries = true; + remote_model.default_reasoning_summary = ReasoningSummary::Detailed; mount_models_once( &server, ModelsResponse { @@ -175,7 +176,7 @@ async fn remote_models_long_model_slug_is_sent_with_high_reasoning() -> Result<( sandbox_policy: config.permissions.sandbox_policy.get().clone(), model: requested_model.to_string(), effort: None, - summary: config.model_reasoning_summary, + summary: None, collaboration_mode: None, personality: None, }) @@ -189,8 +190,13 @@ async fn remote_models_long_model_slug_is_sent_with_high_reasoning() -> Result<( .get("reasoning") .and_then(|reasoning| reasoning.get("effort")) .and_then(|value| value.as_str()); + let reasoning_summary = body + .get("reasoning") + .and_then(|reasoning| reasoning.get("summary")) + .and_then(|value| value.as_str()); assert_eq!(body["model"].as_str(), Some(requested_model)); assert_eq!(reasoning_effort, Some("high")); + 
assert_eq!(reasoning_summary, Some("detailed")); Ok(()) } @@ -227,7 +233,11 @@ async fn namespaced_model_slug_uses_catalog_metadata_without_fallback_warning() sandbox_policy: config.permissions.sandbox_policy.get().clone(), model: requested_model.to_string(), effort: None, - summary: config.model_reasoning_summary, + summary: Some( + config + .model_reasoning_summary + .unwrap_or(ReasoningSummary::Auto), + ), collaboration_mode: None, personality: None, }) @@ -284,8 +294,10 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> { base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -379,7 +391,7 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: REMOTE_MODEL_SLUG.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: Some(ReasoningSummary::Auto), collaboration_mode: None, personality: None, }) @@ -520,8 +532,10 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> { base_instructions: remote_base.to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -590,7 +604,7 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: model.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: Some(ReasoningSummary::Auto), collaboration_mode: None, personality: None, }) @@ 
-980,8 +994,10 @@ fn test_remote_model_with_policy( base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy, supports_parallel_tool_calls: false, diff --git a/codex-rs/core/tests/suite/request_permissions.rs b/codex-rs/core/tests/suite/request_permissions.rs index cda482079a2..43ed8941d17 100644 --- a/codex-rs/core/tests/suite/request_permissions.rs +++ b/codex-rs/core/tests/suite/request_permissions.rs @@ -4,7 +4,6 @@ use anyhow::Result; use codex_core::config::Constrained; use codex_core::features::Feature; use codex_core::sandboxing::SandboxPermissions; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::models::FileSystemPermissions; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::AskForApproval; @@ -111,7 +110,7 @@ async fn submit_turn( sandbox_policy, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/request_user_input.rs b/codex-rs/core/tests/suite/request_user_input.rs index 64f2e0b6db6..1f20f5dfdf0 100644 --- a/codex-rs/core/tests/suite/request_user_input.rs +++ b/codex-rs/core/tests/suite/request_user_input.rs @@ -5,7 +5,6 @@ use std::collections::HashMap; use codex_core::features::Feature; use codex_protocol::config_types::CollaborationMode; use codex_protocol::config_types::ModeKind; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::Settings; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; @@ -138,7 +137,7 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - 
summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: Some(CollaborationMode { mode, settings: Settings { @@ -254,7 +253,7 @@ where sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: Some(collaboration_mode), personality: None, }) diff --git a/codex-rs/core/tests/suite/resume_warning.rs b/codex-rs/core/tests/suite/resume_warning.rs index b0503d5929c..4742b417ae8 100644 --- a/codex-rs/core/tests/suite/resume_warning.rs +++ b/codex-rs/core/tests/suite/resume_warning.rs @@ -4,6 +4,7 @@ use codex_core::CodexAuth; use codex_core::NewThread; use codex_protocol::ThreadId; use codex_protocol::config_types::ModeKind; +use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::InitialHistory; use codex_protocol::protocol::ResumedHistory; @@ -27,6 +28,8 @@ fn resume_history( let turn_ctx = TurnContextItem { turn_id: Some(turn_id.clone()), cwd: config.cwd.clone(), + current_date: None, + timezone: None, approval_policy: config.permissions.approval_policy.value(), sandbox_policy: config.permissions.sandbox_policy.get().clone(), network: None, @@ -34,7 +37,9 @@ fn resume_history( personality: None, collaboration_mode: None, effort: config.model_reasoning_effort, - summary: config.model_reasoning_summary, + summary: config + .model_reasoning_summary + .unwrap_or(ReasoningSummary::Auto), user_instructions: None, developer_instructions: None, final_output_json_schema: None, diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index df389c5ef74..dfbac85a096 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -104,6 +104,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -128,7 +129,7 @@ async fn 
stdio_server_round_trip() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -246,6 +247,7 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -295,7 +297,7 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -398,8 +400,10 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -462,6 +466,7 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -492,7 +497,7 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re sandbox_policy: SandboxPolicy::new_read_only_policy(), model: text_only_model_slug.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -580,6 +585,7 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -604,7 +610,7 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { sandbox_policy: 
SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -739,6 +745,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -763,7 +770,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -958,6 +965,7 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -982,7 +990,7 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/safety_check_downgrade.rs b/codex-rs/core/tests/suite/safety_check_downgrade.rs index 247cfdaedc0..263e4b96dc2 100644 --- a/codex-rs/core/tests/suite/safety_check_downgrade.rs +++ b/codex-rs/core/tests/suite/safety_check_downgrade.rs @@ -1,5 +1,4 @@ use anyhow::Result; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::AskForApproval; @@ -49,7 +48,7 @@ async fn openai_model_header_mismatch_emits_warning_event_and_warning_item() -> sandbox_policy: SandboxPolicy::DangerFullAccess, model: REQUESTED_MODEL.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -146,7 +145,7 @@ async fn 
response_model_field_mismatch_emits_warning_when_header_matches_request sandbox_policy: SandboxPolicy::DangerFullAccess, model: REQUESTED_MODEL.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -230,7 +229,7 @@ async fn openai_model_header_mismatch_only_emits_one_warning_per_turn() -> Resul sandbox_policy: SandboxPolicy::DangerFullAccess, model: REQUESTED_MODEL.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -278,7 +277,7 @@ async fn openai_model_header_casing_only_mismatch_does_not_warn() -> Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: REQUESTED_MODEL.to_string(), effort: test.config.model_reasoning_effort, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/search_tool.rs b/codex-rs/core/tests/suite/search_tool.rs index 1caab45b12b..d852c9cfb30 100644 --- a/codex-rs/core/tests/suite/search_tool.rs +++ b/codex-rs/core/tests/suite/search_tool.rs @@ -133,6 +133,7 @@ fn rmcp_server_config(command: String) -> McpServerConfig { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, } } diff --git a/codex-rs/core/tests/suite/shell_snapshot.rs b/codex-rs/core/tests/suite/shell_snapshot.rs index 3be9d2b5798..a6f8ada10f6 100644 --- a/codex-rs/core/tests/suite/shell_snapshot.rs +++ b/codex-rs/core/tests/suite/shell_snapshot.rs @@ -1,6 +1,5 @@ use anyhow::Result; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::ExecCommandBeginEvent; @@ -162,7 +161,7 @@ async fn run_snapshot_command_with_options( sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, 
effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -248,7 +247,7 @@ async fn run_shell_command_snapshot_with_options( sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -317,7 +316,7 @@ async fn run_tool_turn_on_harness( sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -535,7 +534,7 @@ async fn shell_command_snapshot_still_intercepts_apply_patch() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/skill_approval.rs b/codex-rs/core/tests/suite/skill_approval.rs index a3cfd4d7fdf..a51c69bcfa3 100644 --- a/codex-rs/core/tests/suite/skill_approval.rs +++ b/codex-rs/core/tests/suite/skill_approval.rs @@ -2,13 +2,15 @@ #![cfg(unix)] use anyhow::Result; +use codex_core::config::Config; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::models::FileSystemPermissions; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::ExecApprovalRequestEvent; use codex_protocol::protocol::Op; +use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::SandboxPolicy; use codex_protocol::user_input::UserInput; use core_test_support::responses::mount_function_call_agent_response; @@ -56,7 +58,7 @@ async fn submit_turn_with_policies( sandbox_policy, model: test.session_configured.model.clone(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, 
}) @@ -65,6 +67,24 @@ async fn submit_turn_with_policies( } fn write_skill_with_shell_script(home: &Path, name: &str, script_name: &str) -> Result { + write_skill_with_shell_script_contents( + home, + name, + script_name, + r#"#!/bin/sh +echo 'zsh-fork-stdout' +echo 'zsh-fork-stderr' >&2 +"#, + ) +} + +#[cfg(unix)] +fn write_skill_with_shell_script_contents( + home: &Path, + name: &str, + script_name: &str, + script_contents: &str, +) -> Result { use std::os::unix::fs::PermissionsExt; let skill_dir = home.join("skills").join(name); @@ -82,13 +102,7 @@ description: {name} skill )?; let script_path = scripts_dir.join(script_name); - fs::write( - &script_path, - r#"#!/bin/sh -echo 'zsh-fork-stdout' -echo 'zsh-fork-stderr' >&2 -"#, - )?; + fs::write(&script_path, script_contents)?; let mut permissions = fs::metadata(&script_path)?.permissions(); permissions.set_mode(0o755); fs::set_permissions(&script_path, permissions)?; @@ -129,34 +143,134 @@ fn supports_exec_wrapper_intercept(zsh_path: &Path) -> bool { } } -#[cfg(unix)] -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn shell_zsh_fork_prompts_for_skill_script_execution() -> Result<()> { - use codex_config::Constrained; - use codex_protocol::protocol::ReviewDecision; +#[derive(Clone)] +struct ZshForkRuntime { + zsh_path: PathBuf, + main_execve_wrapper_exe: PathBuf, +} - skip_if_no_network!(Ok(())); +impl ZshForkRuntime { + fn apply_to_config( + &self, + config: &mut Config, + approval_policy: AskForApproval, + sandbox_policy: SandboxPolicy, + ) { + use codex_config::Constrained; + config.features.enable(Feature::ShellTool); + config.features.enable(Feature::ShellZshFork); + config.zsh_path = Some(self.zsh_path.clone()); + config.main_execve_wrapper_exe = Some(self.main_execve_wrapper_exe.clone()); + config.permissions.allow_login_shell = false; + config.permissions.approval_policy = Constrained::allow_any(approval_policy); + config.permissions.sandbox_policy = 
Constrained::allow_any(sandbox_policy); + } +} + +fn restrictive_workspace_write_policy() -> SandboxPolicy { + SandboxPolicy::WorkspaceWrite { + writable_roots: Vec::new(), + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + } +} + +fn zsh_fork_runtime(test_name: &str) -> Result> { let Some(zsh_path) = find_test_zsh_path()? else { - return Ok(()); + return Ok(None); }; if !supports_exec_wrapper_intercept(&zsh_path) { eprintln!( - "skipping zsh-fork skill test: zsh does not support EXEC_WRAPPER intercepts ({})", + "skipping {test_name}: zsh does not support EXEC_WRAPPER intercepts ({})", zsh_path.display() ); - return Ok(()); + return Ok(None); } let Ok(main_execve_wrapper_exe) = codex_utils_cargo_bin::cargo_bin("codex-execve-wrapper") else { - eprintln!("skipping zsh-fork skill test: unable to resolve `codex-execve-wrapper` binary"); + eprintln!("skipping {test_name}: unable to resolve `codex-execve-wrapper` binary"); + return Ok(None); + }; + + Ok(Some(ZshForkRuntime { + zsh_path, + main_execve_wrapper_exe, + })) +} + +async fn build_zsh_fork_test( + server: &wiremock::MockServer, + runtime: ZshForkRuntime, + approval_policy: AskForApproval, + sandbox_policy: SandboxPolicy, + pre_build_hook: F, +) -> Result +where + F: FnOnce(&Path) + Send + 'static, +{ + let mut builder = test_codex() + .with_pre_build_hook(pre_build_hook) + .with_config(move |config| { + runtime.apply_to_config(config, approval_policy, sandbox_policy); + }); + builder.build(server).await +} + +fn skill_script_command(test: &TestCodex, script_name: &str) -> Result<(String, String)> { + let script_path = fs::canonicalize( + test.codex_home_path() + .join("skills/mbolin-test-skill/scripts") + .join(script_name), + )?; + let script_path_str = script_path.to_string_lossy().into_owned(); + let command = shlex::try_join([script_path_str.as_str()])?; + Ok((script_path_str, command)) +} + +async fn 
wait_for_exec_approval_request(test: &TestCodex) -> Option { + wait_for_event_match(test.codex.as_ref(), |event| match event { + EventMsg::ExecApprovalRequest(request) => Some(Some(request.clone())), + EventMsg::TurnComplete(_) => Some(None), + _ => None, + }) + .await +} + +async fn wait_for_turn_complete(test: &TestCodex) { + wait_for_event(test.codex.as_ref(), |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; +} + +fn output_shows_sandbox_denial(output: &str) -> bool { + output.contains("Permission denied") + || output.contains("Operation not permitted") + || output.contains("Read-only file system") +} + +/// Focus on the approval payload: the skill should prompt before execution and +/// only advertise the permissions declared in its metadata. +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_zsh_fork_prompts_for_skill_script_execution() -> Result<()> { + skip_if_no_network!(Ok(())); + + let Some(runtime) = zsh_fork_runtime("zsh-fork skill prompt test")? 
else { return Ok(()); }; let server = start_mock_server().await; let tool_call_id = "zsh-fork-skill-call"; - let mut builder = test_codex() - .with_pre_build_hook(|home| { + let test = build_zsh_fork_test( + &server, + runtime, + AskForApproval::OnRequest, + SandboxPolicy::new_workspace_write_policy(), + |home| { write_skill_with_shell_script(home, "mbolin-test-skill", "hello-mbolin.sh").unwrap(); write_skill_metadata( home, @@ -171,25 +285,11 @@ permissions: "#, ) .unwrap(); - }) - .with_config(move |config| { - config.features.enable(Feature::ShellTool); - config.features.enable(Feature::ShellZshFork); - config.zsh_path = Some(zsh_path.clone()); - config.main_execve_wrapper_exe = Some(main_execve_wrapper_exe); - config.permissions.allow_login_shell = false; - config.permissions.approval_policy = Constrained::allow_any(AskForApproval::OnRequest); - config.permissions.sandbox_policy = - Constrained::allow_any(SandboxPolicy::new_workspace_write_policy()); - }); - let test = builder.build(&server).await?; + }, + ) + .await?; - let script_path = fs::canonicalize( - test.codex_home_path() - .join("skills/mbolin-test-skill/scripts/hello-mbolin.sh"), - )?; - let script_path_str = script_path.to_string_lossy().into_owned(); - let command = shlex::try_join([script_path_str.as_str()])?; + let (script_path_str, command) = skill_script_command(&test, "hello-mbolin.sh")?; let arguments = shell_command_arguments(&command)?; let mocks = mount_function_call_agent_response(&server, tool_call_id, &arguments, "shell_command") @@ -203,12 +303,7 @@ permissions: ) .await?; - let maybe_approval = wait_for_event_match(test.codex.as_ref(), |event| match event { - EventMsg::ExecApprovalRequest(request) => Some(Some(request.clone())), - EventMsg::TurnComplete(_) => Some(None), - _ => None, - }) - .await; + let maybe_approval = wait_for_exec_approval_request(&test).await; let approval = match maybe_approval { Some(approval) => approval, None => { @@ -250,10 +345,7 @@ permissions: }) .await?; 
- wait_for_event(test.codex.as_ref(), |event| { - matches!(event, EventMsg::TurnComplete(_)) - }) - .await; + wait_for_turn_complete(&test).await; let call_output = mocks .completion @@ -268,58 +360,350 @@ permissions: Ok(()) } +/// Look for `additional_permissions == None`, then verify that both the first +/// run and the cached session-approval rerun stay inside the turn sandbox. #[cfg(unix)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn shell_zsh_fork_still_enforces_workspace_write_sandbox() -> Result<()> { - use codex_config::Constrained; - use codex_protocol::protocol::AskForApproval; - +async fn shell_zsh_fork_skill_without_permissions_inherits_turn_sandbox() -> Result<()> { skip_if_no_network!(Ok(())); - let Some(zsh_path) = find_test_zsh_path()? else { + let Some(runtime) = zsh_fork_runtime("zsh-fork inherited skill sandbox test")? else { return Ok(()); }; - if !supports_exec_wrapper_intercept(&zsh_path) { - eprintln!( - "skipping zsh-fork sandbox test: zsh does not support EXEC_WRAPPER intercepts ({})", - zsh_path.display() - ); + + let outside_dir = tempfile::tempdir_in(std::env::current_dir()?)?; + let outside_path = outside_dir + .path() + .join("zsh-fork-skill-inherited-sandbox.txt"); + let outside_path_quoted = shlex::try_join([outside_path.to_string_lossy().as_ref()])?; + let script_contents = format!( + "#!/bin/sh\nprintf '%s' forbidden > {outside_path_quoted}\ncat {outside_path_quoted}\n" + ); + let outside_path_for_hook = outside_path.clone(); + let script_contents_for_hook = script_contents.clone(); + let workspace_write_policy = restrictive_workspace_write_policy(); + + let server = start_mock_server().await; + let test = build_zsh_fork_test( + &server, + runtime, + AskForApproval::OnRequest, + workspace_write_policy.clone(), + move |home| { + let _ = fs::remove_file(&outside_path_for_hook); + write_skill_with_shell_script_contents( + home, + "mbolin-test-skill", + "sandboxed.sh", + &script_contents_for_hook, + ) + 
.unwrap(); + }, + ) + .await?; + + let (script_path_str, command) = skill_script_command(&test, "sandboxed.sh")?; + + let first_call_id = "zsh-fork-skill-permissions-1"; + let first_arguments = shell_command_arguments(&command)?; + let first_mocks = mount_function_call_agent_response( + &server, + first_call_id, + &first_arguments, + "shell_command", + ) + .await; + + submit_turn_with_policies( + &test, + "use $mbolin-test-skill", + AskForApproval::OnRequest, + workspace_write_policy.clone(), + ) + .await?; + + let maybe_approval = wait_for_exec_approval_request(&test).await; + let approval = match maybe_approval { + Some(approval) => approval, + None => panic!("expected exec approval request before completion"), + }; + assert_eq!(approval.call_id, first_call_id); + assert_eq!(approval.command, vec![script_path_str.clone()]); + assert_eq!(approval.additional_permissions, None); + + test.codex + .submit(Op::ExecApproval { + id: approval.effective_approval_id(), + turn_id: None, + decision: ReviewDecision::ApprovedForSession, + }) + .await?; + + wait_for_turn_complete(&test).await; + + let first_output = first_mocks + .completion + .single_request() + .function_call_output(first_call_id)["output"] + .as_str() + .unwrap_or_default() + .to_string(); + assert!( + output_shows_sandbox_denial(&first_output) || !first_output.contains("forbidden"), + "expected inherited turn sandbox denial on first run, got output: {first_output:?}" + ); + assert!( + !outside_path.exists(), + "first run should not write outside the turn sandbox" + ); + + let second_call_id = "zsh-fork-skill-permissions-2"; + let second_arguments = shell_command_arguments(&command)?; + let second_mocks = mount_function_call_agent_response( + &server, + second_call_id, + &second_arguments, + "shell_command", + ) + .await; + + submit_turn_with_policies( + &test, + "use $mbolin-test-skill", + AskForApproval::OnRequest, + workspace_write_policy, + ) + .await?; + + let cached_approval = 
wait_for_exec_approval_request(&test).await; + assert!( + cached_approval.is_none(), + "expected second run to reuse the cached session approval" + ); + + let second_output = second_mocks + .completion + .single_request() + .function_call_output(second_call_id)["output"] + .as_str() + .unwrap_or_default() + .to_string(); + assert!( + output_shows_sandbox_denial(&second_output) || !second_output.contains("forbidden"), + "expected cached skill approval to retain inherited turn sandboxing, got output: {second_output:?}" + ); + assert!( + !outside_path.exists(), + "cached session approval should not widen a permissionless skill to full access" + ); + + Ok(()) +} + +/// The validation to focus on is: writes to the skill-approved folder succeed, +/// and writes to an unrelated folder fail, both before and after cached approval. +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_zsh_fork_skill_session_approval_enforces_skill_permissions() -> Result<()> { + skip_if_no_network!(Ok(())); + + let Some(runtime) = zsh_fork_runtime("zsh-fork explicit skill sandbox test")? 
else { return Ok(()); - } - let Ok(main_execve_wrapper_exe) = codex_utils_cargo_bin::cargo_bin("codex-execve-wrapper") - else { - eprintln!( - "skipping zsh-fork sandbox test: unable to resolve `codex-execve-wrapper` binary" - ); + }; + + let outside_dir = tempfile::tempdir_in(std::env::current_dir()?)?; + let allowed_dir = outside_dir.path().join("allowed-output"); + let blocked_dir = outside_dir.path().join("blocked-output"); + fs::create_dir_all(&allowed_dir)?; + fs::create_dir_all(&blocked_dir)?; + + let allowed_path = allowed_dir.join("allowed.txt"); + let blocked_path = blocked_dir.join("blocked.txt"); + let allowed_path_quoted = shlex::try_join([allowed_path.to_string_lossy().as_ref()])?; + let blocked_path_quoted = shlex::try_join([blocked_path.to_string_lossy().as_ref()])?; + let script_contents = format!( + "#!/bin/sh\nprintf '%s' allowed > {allowed_path_quoted}\ncat {allowed_path_quoted}\nprintf '%s' forbidden > {blocked_path_quoted}\nif [ -f {blocked_path_quoted} ]; then echo blocked-created; fi\n" + ); + let allowed_dir_for_hook = allowed_dir.clone(); + let allowed_path_for_hook = allowed_path.clone(); + let blocked_path_for_hook = blocked_path.clone(); + let script_contents_for_hook = script_contents.clone(); + + let permissions_yaml = format!( + "permissions:\n file_system:\n write:\n - \"{}\"\n", + allowed_dir.display() + ); + + let workspace_write_policy = restrictive_workspace_write_policy(); + let server = start_mock_server().await; + let test = build_zsh_fork_test( + &server, + runtime, + AskForApproval::OnRequest, + workspace_write_policy.clone(), + move |home| { + let _ = fs::remove_file(&allowed_path_for_hook); + let _ = fs::remove_file(&blocked_path_for_hook); + fs::create_dir_all(&allowed_dir_for_hook).unwrap(); + fs::create_dir_all(blocked_path_for_hook.parent().unwrap()).unwrap(); + write_skill_with_shell_script_contents( + home, + "mbolin-test-skill", + "sandboxed.sh", + &script_contents_for_hook, + ) + .unwrap(); + 
write_skill_metadata(home, "mbolin-test-skill", &permissions_yaml).unwrap(); + }, + ) + .await?; + + let (script_path_str, command) = skill_script_command(&test, "sandboxed.sh")?; + + let first_call_id = "zsh-fork-skill-permissions-1"; + let first_arguments = shell_command_arguments(&command)?; + let first_mocks = mount_function_call_agent_response( + &server, + first_call_id, + &first_arguments, + "shell_command", + ) + .await; + + submit_turn_with_policies( + &test, + "use $mbolin-test-skill", + AskForApproval::OnRequest, + workspace_write_policy.clone(), + ) + .await?; + + let maybe_approval = wait_for_exec_approval_request(&test).await; + let approval = match maybe_approval { + Some(approval) => approval, + None => panic!("expected exec approval request before completion"), + }; + assert_eq!(approval.call_id, first_call_id); + assert_eq!(approval.command, vec![script_path_str.clone()]); + assert_eq!( + approval.additional_permissions, + Some(PermissionProfile { + file_system: Some(FileSystemPermissions { + read: None, + write: Some(vec![allowed_dir.clone()]), + }), + ..Default::default() + }) + ); + + test.codex + .submit(Op::ExecApproval { + id: approval.effective_approval_id(), + turn_id: None, + decision: ReviewDecision::ApprovedForSession, + }) + .await?; + + wait_for_turn_complete(&test).await; + + let first_output = first_mocks + .completion + .single_request() + .function_call_output(first_call_id)["output"] + .as_str() + .unwrap_or_default() + .to_string(); + assert!( + first_output.contains("allowed"), + "expected skill sandbox to permit writes to the approved folder, got output: {first_output:?}" + ); + assert_eq!(fs::read_to_string(&allowed_path)?, "allowed"); + assert!( + !blocked_path.exists(), + "first run should not write outside the explicit skill sandbox" + ); + assert!( + !first_output.contains("blocked-created"), + "blocked path should not have been created: {first_output:?}" + ); + + let second_call_id = "zsh-fork-skill-permissions-2"; + let 
second_arguments = shell_command_arguments(&command)?; + let second_mocks = mount_function_call_agent_response( + &server, + second_call_id, + &second_arguments, + "shell_command", + ) + .await; + + let _ = fs::remove_file(&allowed_path); + let _ = fs::remove_file(&blocked_path); + + submit_turn_with_policies( + &test, + "use $mbolin-test-skill", + AskForApproval::OnRequest, + workspace_write_policy, + ) + .await?; + + let cached_approval = wait_for_exec_approval_request(&test).await; + assert!( + cached_approval.is_none(), + "expected second run to reuse the cached session approval" + ); + + let second_output = second_mocks + .completion + .single_request() + .function_call_output(second_call_id)["output"] + .as_str() + .unwrap_or_default() + .to_string(); + assert!( + second_output.contains("allowed"), + "expected cached skill approval to retain the explicit skill sandbox, got output: {second_output:?}" + ); + assert_eq!(fs::read_to_string(&allowed_path)?, "allowed"); + assert!( + !blocked_path.exists(), + "cached session approval should not widen skill execution beyond the explicit skill sandbox" + ); + assert!( + !second_output.contains("blocked-created"), + "blocked path should not have been created after cached approval: {second_output:?}" + ); + + Ok(()) +} + +/// This stays narrow on purpose: the important check is that `WorkspaceWrite` +/// continues to deny writes outside the workspace even under `zsh-fork`. +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_zsh_fork_still_enforces_workspace_write_sandbox() -> Result<()> { + skip_if_no_network!(Ok(())); + + let Some(runtime) = zsh_fork_runtime("zsh-fork workspace sandbox test")? 
else { return Ok(()); }; let server = start_mock_server().await; let tool_call_id = "zsh-fork-workspace-write-deny"; let outside_path = "/tmp/codex-zsh-fork-workspace-write-deny.txt"; - let workspace_write_policy = SandboxPolicy::WorkspaceWrite { - writable_roots: Vec::new(), - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - let policy_for_config = workspace_write_policy.clone(); + let workspace_write_policy = restrictive_workspace_write_policy(); let _ = fs::remove_file(outside_path); - let mut builder = test_codex() - .with_pre_build_hook(move |_| { + let test = build_zsh_fork_test( + &server, + runtime, + AskForApproval::Never, + workspace_write_policy.clone(), + move |_| { let _ = fs::remove_file(outside_path); - }) - .with_config(move |config| { - config.features.enable(Feature::ShellTool); - config.features.enable(Feature::ShellZshFork); - config.zsh_path = Some(zsh_path.clone()); - config.main_execve_wrapper_exe = Some(main_execve_wrapper_exe); - config.permissions.allow_login_shell = false; - config.permissions.approval_policy = Constrained::allow_any(AskForApproval::Never); - config.permissions.sandbox_policy = Constrained::allow_any(policy_for_config); - }); - let test = builder.build(&server).await?; + }, + ) + .await?; let command = format!("touch {outside_path}"); let arguments = shell_command_arguments(&command)?; @@ -335,7 +719,7 @@ async fn shell_zsh_fork_still_enforces_workspace_write_sandbox() -> Result<()> { ) .await?; - wait_for_turn_complete_without_skill_approval(&test).await; + wait_for_turn_complete(&test).await; let call_output = mocks .completion @@ -343,9 +727,7 @@ async fn shell_zsh_fork_still_enforces_workspace_write_sandbox() -> Result<()> { .function_call_output(tool_call_id); let output = call_output["output"].as_str().unwrap_or_default(); assert!( - output.contains("Permission denied") - || output.contains("Operation not permitted") - || 
output.contains("Read-only file system"), + output_shows_sandbox_denial(output), "expected sandbox denial, got output: {output:?}" ); assert!( diff --git a/codex-rs/core/tests/suite/skills.rs b/codex-rs/core/tests/suite/skills.rs index f681ca1b64f..2f786590b18 100644 --- a/codex-rs/core/tests/suite/skills.rs +++ b/codex-rs/core/tests/suite/skills.rs @@ -77,7 +77,7 @@ async fn user_turn_includes_skill_instructions() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: codex_protocol::config_types::ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/snapshots/all__suite__model_visible_layout__model_visible_layout_environment_context_includes_one_subagent.snap b/codex-rs/core/tests/suite/snapshots/all__suite__model_visible_layout__model_visible_layout_environment_context_includes_one_subagent.snap new file mode 100644 index 00000000000..3436943cd29 --- /dev/null +++ b/codex-rs/core/tests/suite/snapshots/all__suite__model_visible_layout__model_visible_layout_environment_context_includes_one_subagent.snap @@ -0,0 +1,6 @@ +--- +source: core/tests/suite/model_visible_layout.rs +assertion_line: 476 +expression: "format_environment_context_subagents_snapshot(&[\"- agent-1: Atlas\"])" +--- +00:message/user::subagents=1> diff --git a/codex-rs/core/tests/suite/snapshots/all__suite__model_visible_layout__model_visible_layout_environment_context_includes_two_subagents.snap b/codex-rs/core/tests/suite/snapshots/all__suite__model_visible_layout__model_visible_layout_environment_context_includes_two_subagents.snap new file mode 100644 index 00000000000..105c28515b2 --- /dev/null +++ b/codex-rs/core/tests/suite/snapshots/all__suite__model_visible_layout__model_visible_layout_environment_context_includes_two_subagents.snap @@ -0,0 +1,6 @@ +--- +source: core/tests/suite/model_visible_layout.rs +assertion_line: 486 +expression: 
"format_environment_context_subagents_snapshot(&[\"- agent-1: Atlas\",\n\"- agent-2: Juniper\",])" +--- +00:message/user::subagents=2> diff --git a/codex-rs/core/tests/suite/subagent_notifications.rs b/codex-rs/core/tests/suite/subagent_notifications.rs index 422510f3286..3dc463c4abf 100644 --- a/codex-rs/core/tests/suite/subagent_notifications.rs +++ b/codex-rs/core/tests/suite/subagent_notifications.rs @@ -20,6 +20,8 @@ use tokio::time::sleep; use wiremock::MockServer; const SPAWN_CALL_ID: &str = "spawn-call-1"; +const FORKED_SPAWN_AGENT_OUTPUT_MESSAGE: &str = "You are the newly spawned agent. The prior conversation history was forked from your parent agent. Treat the next user message as your new task, and use the forked history only as background context."; +const TURN_0_FORK_PROMPT: &str = "seed fork context"; const TURN_1_PROMPT: &str = "spawn a child and continue"; const TURN_2_NO_WAIT_PROMPT: &str = "follow up without wait"; const CHILD_PROMPT: &str = "child: do work"; @@ -194,3 +196,116 @@ async fn subagent_notification_is_included_without_wait() -> Result<()> { Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn spawned_child_receives_forked_parent_context() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let seed_turn = mount_sse_once_match( + &server, + |req: &wiremock::Request| body_contains(req, TURN_0_FORK_PROMPT), + sse(vec![ + ev_response_created("resp-seed-1"), + ev_assistant_message("msg-seed-1", "seeded"), + ev_completed("resp-seed-1"), + ]), + ) + .await; + + let spawn_args = serde_json::to_string(&json!({ + "message": CHILD_PROMPT, + "fork_context": true, + }))?; + let spawn_turn = mount_sse_once_match( + &server, + |req: &wiremock::Request| body_contains(req, TURN_1_PROMPT), + sse(vec![ + ev_response_created("resp-turn1-1"), + ev_function_call(SPAWN_CALL_ID, "spawn_agent", &spawn_args), + ev_completed("resp-turn1-1"), + ]), + ) + .await; + + let _child_request_log = 
mount_sse_once_match( + &server, + |req: &wiremock::Request| body_contains(req, CHILD_PROMPT), + sse(vec![ + ev_response_created("resp-child-1"), + ev_assistant_message("msg-child-1", "child done"), + ev_completed("resp-child-1"), + ]), + ) + .await; + + let _turn1_followup = mount_sse_once_match( + &server, + |req: &wiremock::Request| body_contains(req, SPAWN_CALL_ID), + sse(vec![ + ev_response_created("resp-turn1-2"), + ev_assistant_message("msg-turn1-2", "parent done"), + ev_completed("resp-turn1-2"), + ]), + ) + .await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Collab); + }); + let test = builder.build(&server).await?; + + test.submit_turn(TURN_0_FORK_PROMPT).await?; + let _ = seed_turn.single_request(); + + test.submit_turn(TURN_1_PROMPT).await?; + let _ = spawn_turn.single_request(); + + let deadline = Instant::now() + Duration::from_secs(2); + let child_request = loop { + if let Some(request) = server + .received_requests() + .await + .unwrap_or_default() + .into_iter() + .find(|request| { + body_contains(request, CHILD_PROMPT) + && body_contains(request, FORKED_SPAWN_AGENT_OUTPUT_MESSAGE) + }) + { + break request; + } + if Instant::now() >= deadline { + anyhow::bail!("timed out waiting for forked child request"); + } + sleep(Duration::from_millis(10)).await; + }; + assert!(body_contains(&child_request, TURN_0_FORK_PROMPT)); + assert!(body_contains(&child_request, "seeded")); + + let child_body = child_request + .body_json::() + .expect("forked child request body should be json"); + let function_call_output = child_body["input"] + .as_array() + .and_then(|items| { + items.iter().find(|item| { + item["type"].as_str() == Some("function_call_output") + && item["call_id"].as_str() == Some(SPAWN_CALL_ID) + }) + }) + .unwrap_or_else(|| panic!("expected forked child request to include spawn_agent output")); + let (content, success) = match &function_call_output["output"] { + serde_json::Value::String(text) => 
(Some(text.as_str()), None), + serde_json::Value::Object(output) => ( + output.get("content").and_then(serde_json::Value::as_str), + output.get("success").and_then(serde_json::Value::as_bool), + ), + _ => (None, None), + }; + assert_eq!(content, Some(FORKED_SPAWN_AGENT_OUTPUT_MESSAGE)); + assert_ne!(success, Some(false)); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/tool_harness.rs b/codex-rs/core/tests/suite/tool_harness.rs index 507c8eb0683..13191892cef 100644 --- a/codex-rs/core/tests/suite/tool_harness.rs +++ b/codex-rs/core/tests/suite/tool_harness.rs @@ -4,7 +4,6 @@ use std::fs; use assert_matches::assert_matches; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::plan_tool::StepStatus; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; @@ -89,7 +88,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -158,7 +157,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -237,7 +236,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -328,7 +327,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<() sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -427,7 
+426,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/tool_parallelism.rs b/codex-rs/core/tests/suite/tool_parallelism.rs index 357dcff0fd0..1678e58349b 100644 --- a/codex-rs/core/tests/suite/tool_parallelism.rs +++ b/codex-rs/core/tests/suite/tool_parallelism.rs @@ -5,7 +5,6 @@ use std::fs; use std::time::Duration; use std::time::Instant; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -46,7 +45,7 @@ async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -361,7 +360,7 @@ async fn shell_tools_start_before_response_completed_when_stream_delayed() -> an sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/truncation.rs b/codex-rs/core/tests/suite/truncation.rs index 8ede6c621f6..6ba6e53b1b8 100644 --- a/codex-rs/core/tests/suite/truncation.rs +++ b/codex-rs/core/tests/suite/truncation.rs @@ -5,7 +5,6 @@ use anyhow::Context; use anyhow::Result; use codex_core::config::types::McpServerConfig; use codex_core::config::types::McpServerTransportConfig; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -371,6 +370,7 @@ async fn mcp_tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> 
enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -465,6 +465,7 @@ async fn mcp_image_output_preserves_image_and_no_text_summary() -> Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config @@ -488,7 +489,7 @@ async fn mcp_image_output_preserves_image_and_no_text_summary() -> Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -730,6 +731,7 @@ async fn mcp_tool_call_output_not_truncated_with_custom_limit() -> Result<()> { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }, ); config diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 705e4156203..55bc19e6809 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -6,7 +6,6 @@ use std::sync::OnceLock; use anyhow::Context; use anyhow::Result; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::ExecCommandSource; @@ -209,7 +208,7 @@ async fn unified_exec_intercepts_apply_patch_exec_command() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -338,7 +337,7 @@ async fn unified_exec_emits_exec_command_begin_event() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -416,7 +415,7 @@ async fn unified_exec_resolves_relative_workdir() -> Result<()> { sandbox_policy: 
SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -497,7 +496,7 @@ async fn unified_exec_respects_workdir_override() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -590,7 +589,7 @@ async fn unified_exec_emits_exec_command_end_event() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -665,7 +664,7 @@ async fn unified_exec_emits_output_delta_for_exec_command() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -741,7 +740,7 @@ async fn unified_exec_full_lifecycle_with_background_end_event() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -871,7 +870,7 @@ async fn unified_exec_emits_terminal_interaction_for_write_stdin() -> Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1008,7 +1007,7 @@ async fn unified_exec_terminal_interaction_captures_delayed_output() -> Result<( sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1168,7 +1167,7 @@ async fn unified_exec_emits_one_begin_and_one_end_event() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, 
model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1266,7 +1265,7 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1384,7 +1383,7 @@ async fn unified_exec_defaults_to_pipe() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1474,7 +1473,7 @@ async fn unified_exec_can_enable_tty() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1555,7 +1554,7 @@ async fn unified_exec_respects_early_exit_notifications() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1686,7 +1685,7 @@ async fn write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1854,7 +1853,7 @@ async fn unified_exec_emits_end_event_when_session_dies_via_stdin() -> Result<() sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -1931,7 +1930,7 @@ async fn unified_exec_keeps_long_running_session_after_turn_end() -> Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: 
ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2019,7 +2018,7 @@ async fn unified_exec_interrupt_terminates_long_running_session() -> Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2116,7 +2115,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2251,7 +2250,7 @@ PY sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2365,7 +2364,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2461,7 +2460,7 @@ PY sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2543,7 +2542,7 @@ async fn unified_exec_runs_under_sandbox() -> Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2647,7 +2646,7 @@ async fn unified_exec_python_prompt_under_seatbelt() -> Result<()> { sandbox_policy: SandboxPolicy::new_read_only_policy(), model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2742,7 +2741,7 @@ async fn unified_exec_runs_on_all_platforms() -> Result<()> { sandbox_policy: 
SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -2877,7 +2876,7 @@ async fn unified_exec_prunes_exited_sessions_first() -> Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/user_shell_cmd.rs b/codex-rs/core/tests/suite/user_shell_cmd.rs index c38b86e4436..766d79abb85 100644 --- a/codex-rs/core/tests/suite/user_shell_cmd.rs +++ b/codex-rs/core/tests/suite/user_shell_cmd.rs @@ -1,6 +1,5 @@ use anyhow::Context; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::ExecCommandEndEvent; @@ -178,7 +177,7 @@ async fn user_shell_command_does_not_replace_active_turn() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: fixture.session_configured.model.clone(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/view_image.rs b/codex-rs/core/tests/suite/view_image.rs index 36ccf1a25e6..4521ed45f8d 100644 --- a/codex-rs/core/tests/suite/view_image.rs +++ b/codex-rs/core/tests/suite/view_image.rs @@ -112,7 +112,7 @@ async fn user_turn_with_local_image_attaches_image() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -214,7 +214,7 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, 
collaboration_mode: None, personality: None, }) @@ -344,7 +344,7 @@ console.log(out.output?.body?.text ?? ""); sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -358,40 +358,26 @@ console.log(out.output?.body?.text ?? ""); .await; let req = mock.single_request(); - let (js_repl_output, js_repl_success) = req - .custom_tool_call_output_content_and_success(call_id) - .expect("custom tool output present"); - let js_repl_output = js_repl_output.expect("custom tool output text present"); - assert_ne!( - js_repl_success, - Some(false), - "js_repl call failed unexpectedly: {js_repl_output}" - ); - let body = req.body_json(); - let image_messages = image_messages(&body); assert_eq!( - image_messages.len(), - 1, - "js_repl view_image should inject exactly one pending input image message" + image_messages(&body).len(), + 0, + "js_repl view_image should not inject a pending input image message" ); - let image_message = image_messages - .into_iter() - .next() - .expect("pending input image message not included in request"); - let image_url = image_message - .get("content") + + let custom_output = req.custom_tool_call_output(call_id); + let output_items = custom_output + .get("output") .and_then(Value::as_array) - .and_then(|content| { - content.iter().find_map(|span| { - if span.get("type").and_then(Value::as_str) == Some("input_image") { - span.get("image_url").and_then(Value::as_str) - } else { - None - } - }) + .expect("custom_tool_call_output should be a content item array"); + let image_url = output_items + .iter() + .find_map(|item| { + (item.get("type").and_then(Value::as_str) == Some("input_image")) + .then(|| item.get("image_url").and_then(Value::as_str)) + .flatten() }) - .expect("image_url present"); + .expect("image_url present in js_repl custom tool output"); assert!( image_url.starts_with("data:image/png;base64,"), "expected 
png data URL, got {image_url}" @@ -447,7 +433,7 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -522,7 +508,7 @@ async fn view_image_tool_placeholder_for_non_image_files() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -614,7 +600,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> { sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -675,8 +661,10 @@ async fn view_image_tool_returns_unsupported_message_for_text_only_model() -> an base_instructions: "base instructions".to_string(), model_messages: None, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, + availability_nux: None, apply_patch_tool_type: None, truncation_policy: TruncationPolicyConfig::bytes(10_000), supports_parallel_tool_calls: false, @@ -736,7 +724,7 @@ async fn view_image_tool_returns_unsupported_message_for_text_only_model() -> an sandbox_policy: SandboxPolicy::DangerFullAccess, model: model_slug.to_string(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) @@ -812,7 +800,7 @@ async fn replaces_invalid_local_image_after_bad_request() -> anyhow::Result<()> sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_model, effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/core/tests/suite/websocket_fallback.rs 
b/codex-rs/core/tests/suite/websocket_fallback.rs index 20c7b3b2f25..9ff5cbe0013 100644 --- a/codex-rs/core/tests/suite/websocket_fallback.rs +++ b/codex-rs/core/tests/suite/websocket_fallback.rs @@ -1,6 +1,5 @@ use anyhow::Result; use codex_core::features::Feature; -use codex_protocol::config_types::ReasoningSummary; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::Op; @@ -161,7 +160,7 @@ async fn websocket_fallback_hides_first_websocket_retry_stream_error() -> Result sandbox_policy: SandboxPolicy::DangerFullAccess, model: session_configured.model.clone(), effort: None, - summary: ReasoningSummary::Auto, + summary: None, collaboration_mode: None, personality: None, }) diff --git a/codex-rs/docs/codex_mcp_interface.md b/codex-rs/docs/codex_mcp_interface.md index 3a3abe9a9ec..763f2602ad7 100644 --- a/codex-rs/docs/codex_mcp_interface.md +++ b/codex-rs/docs/codex_mcp_interface.md @@ -90,20 +90,26 @@ directory (it returns the restored thread summary). Fetch the catalog of models available in the current Codex build with `model/list`. The request accepts optional pagination inputs: -- `pageSize` – number of models to return (defaults to a server-selected value) +- `limit` – number of models to return (defaults to a server-selected value) - `cursor` – opaque string from the previous response’s `nextCursor` Each response yields: -- `items` – ordered list of models. A model includes: +- `data` – ordered list of models. 
A model includes: - `id`, `model`, `displayName`, `description` - `supportedReasoningEfforts` – array of objects with: - - `reasoningEffort` – one of `minimal|low|medium|high` + - `reasoningEffort` – one of `none|minimal|low|medium|high|xhigh` - `description` – human-friendly label for the effort - `defaultReasoningEffort` – suggested effort for the UI + - `inputModalities` – accepted input types for the model - `supportsPersonality` – whether the model supports personality-specific instructions - `isDefault` – whether the model is recommended for most users - `upgrade` – optional recommended upgrade model id + - `upgradeInfo` – optional upgrade metadata object with: + - `model` – recommended upgrade model id + - `upgradeCopy` – optional display copy for the upgrade recommendation + - `modelLink` – optional link for the upgrade recommendation + - `migrationMarkdown` – optional markdown shown when presenting the upgrade - `nextCursor` – pass into the next request to continue paging (optional) ## Collaboration modes (experimental) diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 1dca110aff6..e8ab6f8af74 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -377,7 +377,6 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result let default_approval_policy = config.permissions.approval_policy.value(); let default_sandbox_policy = config.permissions.sandbox_policy.get(); let default_effort = config.model_reasoning_effort; - let default_summary = config.model_reasoning_summary; // When --yolo (dangerously_bypass_approvals_and_sandbox) is set, also skip the git repo check // since the user is explicitly running in an externally sandboxed environment. 
@@ -560,7 +559,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result sandbox_policy: default_sandbox_policy.clone(), model: default_model, effort: default_effort, - summary: default_summary, + summary: None, final_output_json_schema: output_schema, collaboration_mode: None, personality: None, diff --git a/codex-rs/hooks/src/registry.rs b/codex-rs/hooks/src/registry.rs index 6568f5374f8..1648fa04621 100644 --- a/codex-rs/hooks/src/registry.rs +++ b/codex-rs/hooks/src/registry.rs @@ -104,6 +104,7 @@ mod tests { HookPayload { session_id: ThreadId::new(), cwd: PathBuf::from(CWD), + client: None, triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() @@ -172,6 +173,7 @@ mod tests { HookPayload { session_id: ThreadId::new(), cwd: PathBuf::from(CWD), + client: None, triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() diff --git a/codex-rs/hooks/src/types.rs b/codex-rs/hooks/src/types.rs index 5131580128d..b7ef3c2141e 100644 --- a/codex-rs/hooks/src/types.rs +++ b/codex-rs/hooks/src/types.rs @@ -65,6 +65,8 @@ impl Hook { pub struct HookPayload { pub session_id: ThreadId, pub cwd: PathBuf, + #[serde(skip_serializing_if = "Option::is_none")] + pub client: Option, #[serde(serialize_with = "serialize_triggered_at")] pub triggered_at: DateTime, pub hook_event: HookEvent, @@ -181,6 +183,7 @@ mod tests { let payload = HookPayload { session_id, cwd: PathBuf::from("tmp"), + client: None, triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() @@ -218,6 +221,7 @@ mod tests { let payload = HookPayload { session_id, cwd: PathBuf::from("tmp"), + client: None, triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() diff --git a/codex-rs/hooks/src/user_notification.rs b/codex-rs/hooks/src/user_notification.rs index 50472225e4b..caeca59b509 100644 --- a/codex-rs/hooks/src/user_notification.rs +++ b/codex-rs/hooks/src/user_notification.rs @@ -1,4 +1,3 @@ -use std::path::Path; use std::process::Stdio; use 
std::sync::Arc; @@ -19,6 +18,8 @@ enum UserNotification { thread_id: String, turn_id: String, cwd: String, + #[serde(skip_serializing_if = "Option::is_none")] + client: Option, /// Messages that the user sent to the agent to initiate the turn. input_messages: Vec, @@ -28,13 +29,14 @@ enum UserNotification { }, } -pub fn legacy_notify_json(hook_event: &HookEvent, cwd: &Path) -> Result { - match hook_event { +pub fn legacy_notify_json(payload: &HookPayload) -> Result { + match &payload.hook_event { HookEvent::AfterAgent { event } => { serde_json::to_string(&UserNotification::AgentTurnComplete { thread_id: event.thread_id.to_string(), turn_id: event.turn_id.clone(), - cwd: cwd.display().to_string(), + cwd: payload.cwd.display().to_string(), + client: payload.client.clone(), input_messages: event.input_messages.clone(), last_assistant_message: event.last_assistant_message.clone(), }) @@ -56,7 +58,7 @@ pub fn notify_hook(argv: Vec) -> Hook { Some(command) => command, None => return HookResult::Success, }; - if let Ok(notify_payload) = legacy_notify_json(&payload.hook_event, &payload.cwd) { + if let Ok(notify_payload) = legacy_notify_json(payload) { command.arg(notify_payload); } @@ -91,6 +93,7 @@ mod tests { "thread-id": "b5f6c1c2-1111-2222-3333-444455556666", "turn-id": "12345", "cwd": "/Users/example/project", + "client": "codex-tui", "input-messages": ["Rename `foo` to `bar` and update the callsites."], "last-assistant-message": "Rename complete and verified `cargo build` succeeds.", }) @@ -102,6 +105,7 @@ mod tests { thread_id: "b5f6c1c2-1111-2222-3333-444455556666".to_string(), turn_id: "12345".to_string(), cwd: "/Users/example/project".to_string(), + client: Some("codex-tui".to_string()), input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()], last_assistant_message: Some( "Rename complete and verified `cargo build` succeeds.".to_string(), @@ -115,19 +119,27 @@ mod tests { #[test] fn legacy_notify_json_matches_historical_wire_shape() 
-> Result<()> { - let hook_event = HookEvent::AfterAgent { - event: crate::HookEventAfterAgent { - thread_id: ThreadId::from_string("b5f6c1c2-1111-2222-3333-444455556666") - .expect("valid thread id"), - turn_id: "12345".to_string(), - input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()], - last_assistant_message: Some( - "Rename complete and verified `cargo build` succeeds.".to_string(), - ), + let payload = HookPayload { + session_id: ThreadId::new(), + cwd: std::path::Path::new("/Users/example/project").to_path_buf(), + client: Some("codex-tui".to_string()), + triggered_at: chrono::Utc::now(), + hook_event: HookEvent::AfterAgent { + event: crate::HookEventAfterAgent { + thread_id: ThreadId::from_string("b5f6c1c2-1111-2222-3333-444455556666") + .expect("valid thread id"), + turn_id: "12345".to_string(), + input_messages: vec![ + "Rename `foo` to `bar` and update the callsites.".to_string(), + ], + last_assistant_message: Some( + "Rename complete and verified `cargo build` succeeds.".to_string(), + ), + }, }, }; - let serialized = legacy_notify_json(&hook_event, Path::new("/Users/example/project"))?; + let serialized = legacy_notify_json(&payload)?; let actual: Value = serde_json::from_str(&serialized)?; assert_eq!(actual, expected_notification_json()); diff --git a/codex-rs/protocol/src/approvals.rs b/codex-rs/protocol/src/approvals.rs index 98fe0280e3d..12bd65ec598 100644 --- a/codex-rs/protocol/src/approvals.rs +++ b/codex-rs/protocol/src/approvals.rs @@ -2,15 +2,30 @@ use std::collections::HashMap; use std::path::PathBuf; use crate::mcp::RequestId; +use crate::models::MacOsSeatbeltProfileExtensions; use crate::models::PermissionProfile; use crate::parse_command::ParsedCommand; use crate::protocol::FileChange; use crate::protocol::ReviewDecision; +use crate::protocol::SandboxPolicy; use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; use ts_rs::TS; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Permissions 
{ + pub sandbox_policy: SandboxPolicy, + pub macos_seatbelt_profile_extensions: Option, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EscalationPermissions { + PermissionProfile(PermissionProfile), + Permissions(Permissions), +} + /// Proposed execpolicy change to allow commands starting with this prefix. /// /// The `command` tokens form the prefix that would be added as an execpolicy diff --git a/codex-rs/protocol/src/models.rs b/codex-rs/protocol/src/models.rs index ff5030c75eb..14cb9682813 100644 --- a/codex-rs/protocol/src/models.rs +++ b/codex-rs/protocol/src/models.rs @@ -95,6 +95,32 @@ pub enum MacOsAutomationValue { BundleIds(Vec), } +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum MacOsPreferencesPermission { + // IMPORTANT: ReadOnly needs to be the default because it's the + // security-sensitive default and keeps cf prefs working. + #[default] + ReadOnly, + ReadWrite, + None, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum MacOsAutomationPermission { + #[default] + None, + All, + BundleIds(Vec), +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MacOsSeatbeltProfileExtensions { + pub macos_preferences: MacOsPreferencesPermission, + pub macos_automation: MacOsAutomationPermission, + pub macos_accessibility: bool, + pub macos_calendar: bool, +} + #[derive(Debug, Clone, Default, Eq, Hash, PartialEq, Serialize, Deserialize, JsonSchema, TS)] pub struct PermissionProfile { pub network: Option, @@ -135,7 +161,7 @@ pub enum ResponseInputItem { }, CustomToolCallOutput { call_id: String, - output: String, + output: FunctionCallOutputPayload, }, } @@ -235,9 +261,12 @@ pub enum ResponseItem { name: String, input: String, }, + // `custom_tool_call_output.output` uses the same wire encoding as + // `function_call_output.output` so freeform tools can return either plain + // text or structured content items. 
CustomToolCallOutput { call_id: String, - output: String, + output: FunctionCallOutputPayload, }, // Emitted by the Responses API when the agent triggers a web search. // Example payload (from SSE `response.output_item.done`): @@ -1512,6 +1541,26 @@ mod tests { Ok(()) } + #[test] + fn serializes_custom_tool_image_outputs_as_array() -> Result<()> { + let item = ResponseInputItem::CustomToolCallOutput { + call_id: "call1".into(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,BASE64".into(), + }, + ]), + }; + + let json = serde_json::to_string(&item)?; + let v: serde_json::Value = serde_json::from_str(&json)?; + + let output = v.get("output").expect("output field"); + assert!(output.is_array(), "expected array output"); + + Ok(()) + } + #[test] fn preserves_existing_image_data_urls() -> Result<()> { let call_tool_result = CallToolResult { diff --git a/codex-rs/protocol/src/openai_models.rs b/codex-rs/protocol/src/openai_models.rs index 8d01e10d0b9..b63f137442d 100644 --- a/codex-rs/protocol/src/openai_models.rs +++ b/codex-rs/protocol/src/openai_models.rs @@ -15,6 +15,7 @@ use tracing::warn; use ts_rs::TS; use crate::config_types::Personality; +use crate::config_types::ReasoningSummary; use crate::config_types::Verbosity; const PERSONALITY_PLACEHOLDER: &str = "{{ personality }}"; @@ -98,6 +99,11 @@ pub struct ModelUpgrade { pub migration_markdown: Option, } +#[derive(Debug, Clone, Deserialize, Serialize, TS, JsonSchema, PartialEq, Eq)] +pub struct ModelAvailabilityNux { + pub message: String, +} + /// Metadata describing a Codex-supported model. #[derive(Debug, Clone, Deserialize, Serialize, TS, JsonSchema, PartialEq)] pub struct ModelPreset { @@ -122,6 +128,8 @@ pub struct ModelPreset { pub upgrade: Option, /// Whether this preset should appear in the picker UI. pub show_in_picker: bool, + /// Availability NUX shown when this preset becomes accessible to the user. 
+ pub availability_nux: Option, /// whether this model is supported in the api pub supported_in_api: bool, /// Input modalities accepted when composing user turns for this preset. @@ -224,11 +232,14 @@ pub struct ModelInfo { pub visibility: ModelVisibility, pub supported_in_api: bool, pub priority: i32, + pub availability_nux: Option, pub upgrade: Option, pub base_instructions: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub model_messages: Option, pub supports_reasoning_summaries: bool, + #[serde(default)] + pub default_reasoning_summary: ReasoningSummary, pub support_verbosity: bool, pub default_verbosity: Option, pub apply_patch_tool_type: Option, @@ -407,6 +418,7 @@ impl From for ModelPreset { migration_markdown: Some(upgrade.migration_markdown.clone()), }), show_in_picker: info.visibility == ModelVisibility::List, + availability_nux: info.availability_nux, supported_in_api: info.supported_in_api, input_modalities: info.input_modalities, } @@ -492,10 +504,12 @@ mod tests { visibility: ModelVisibility::List, supported_in_api: true, priority: 1, + availability_nux: None, upgrade: None, base_instructions: "base".to_string(), model_messages: spec, supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, support_verbosity: false, default_verbosity: None, apply_patch_tool_type: None, @@ -664,4 +678,57 @@ mod tests { ); assert_eq!(personality_variables.get_personality_message(None), None); } + + #[test] + fn model_info_defaults_availability_nux_to_none_when_omitted() { + let model: ModelInfo = serde_json::from_value(serde_json::json!({ + "slug": "test-model", + "display_name": "Test Model", + "description": null, + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 1, + "upgrade": null, + "base_instructions": "base", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + 
"support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { + "mode": "bytes", + "limit": 10000 + }, + "supports_parallel_tool_calls": false, + "context_window": null, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text", "image"], + "prefer_websockets": false + })) + .expect("deserialize model info"); + + assert_eq!(model.availability_nux, None); + } + + #[test] + fn model_preset_preserves_availability_nux() { + let preset = ModelPreset::from(ModelInfo { + availability_nux: Some(ModelAvailabilityNux { + message: "Try Spark.".to_string(), + }), + ..test_model(None) + }); + + assert_eq!( + preset.availability_nux, + Some(ModelAvailabilityNux { + message: "Try Spark.".to_string(), + }) + ); + } } diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index a401cddb3eb..b5aaf47ea65 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -185,7 +185,11 @@ pub enum Op { effort: Option, /// Will only be honored if the model is configured to use reasoning. - summary: ReasoningSummaryConfig, + /// + /// When omitted, the session keeps the current setting (which allows core to + /// fall back to the selected model's default on new sessions). 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + summary: Option, // The JSON schema to use for the final assistant message final_output_json_schema: Option, @@ -2129,6 +2133,10 @@ pub struct TurnContextItem { #[serde(default, skip_serializing_if = "Option::is_none")] pub turn_id: Option, pub cwd: PathBuf, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub current_date: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timezone: Option, pub approval_policy: AskForApproval, pub sandbox_policy: SandboxPolicy, #[serde(skip_serializing_if = "Option::is_none")] @@ -3357,6 +3365,8 @@ mod tests { let item = TurnContextItem { turn_id: None, cwd: PathBuf::from("/tmp"), + current_date: None, + timezone: None, approval_policy: AskForApproval::Never, sandbox_policy: SandboxPolicy::DangerFullAccess, network: Some(TurnContextNetworkItem { diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index 62a9c3b019f..c71799c6293 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -47,6 +47,7 @@ pub async fn perform_oauth_login( http_headers: Option>, env_http_headers: Option>, scopes: &[String], + oauth_resource: Option<&str>, callback_port: Option, callback_url: Option<&str>, ) -> Result<()> { @@ -60,6 +61,7 @@ pub async fn perform_oauth_login( store_mode, headers, scopes, + oauth_resource, true, callback_port, callback_url, @@ -78,6 +80,7 @@ pub async fn perform_oauth_login_return_url( http_headers: Option>, env_http_headers: Option>, scopes: &[String], + oauth_resource: Option<&str>, timeout_secs: Option, callback_port: Option, callback_url: Option<&str>, @@ -92,6 +95,7 @@ pub async fn perform_oauth_login_return_url( store_mode, headers, scopes, + oauth_resource, false, callback_port, callback_url, @@ -303,6 +307,7 @@ impl OauthLoginFlow { store_mode: OAuthCredentialsStoreMode, headers: OauthHeaders, scopes: 
&[String], + oauth_resource: Option<&str>, launch_browser: bool, callback_port: Option, callback_url: Option<&str>, @@ -340,7 +345,11 @@ impl OauthLoginFlow { oauth_state .start_authorization(&scope_refs, &redirect_uri, Some("Codex")) .await?; - let auth_url = oauth_state.get_authorization_url().await?; + let auth_url = append_query_param( + &oauth_state.get_authorization_url().await?, + "resource", + oauth_resource, + ); let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1); let timeout = Duration::from_secs(timeout_secs as u64); @@ -431,9 +440,29 @@ impl OauthLoginFlow { } } +fn append_query_param(url: &str, key: &str, value: Option<&str>) -> String { + let Some(value) = value else { + return url.to_string(); + }; + let value = value.trim(); + if value.is_empty() { + return url.to_string(); + } + if let Ok(mut parsed) = Url::parse(url) { + parsed.query_pairs_mut().append_pair(key, value); + return parsed.to_string(); + } + let encoded = urlencoding::encode(value); + let separator = if url.contains('?') { "&" } else { "?" 
}; + format!("{url}{separator}{key}={encoded}") +} + #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; + use super::CallbackOutcome; + use super::append_query_param; use super::callback_path_from_redirect_uri; use super::parse_oauth_callback; @@ -461,4 +490,36 @@ mod tests { .expect("redirect URI should parse"); assert_eq!(path, "/oauth/callback"); } + + #[test] + fn append_query_param_adds_resource_to_absolute_url() { + let url = append_query_param( + "https://example.com/authorize?scope=read", + "resource", + Some("https://api.example.com"), + ); + + assert_eq!( + url, + "https://example.com/authorize?scope=read&resource=https%3A%2F%2Fapi.example.com" + ); + } + + #[test] + fn append_query_param_ignores_empty_values() { + let url = append_query_param( + "https://example.com/authorize?scope=read", + "resource", + Some(" "), + ); + + assert_eq!(url, "https://example.com/authorize?scope=read"); + } + + #[test] + fn append_query_param_handles_unparseable_url() { + let url = append_query_param("not a url", "resource", Some("api/resource")); + + assert_eq!(url, "not a url?resource=api%2Fresource"); + } } diff --git a/codex-rs/shell-escalation/Cargo.toml b/codex-rs/shell-escalation/Cargo.toml index 85075cfb51e..fbc5bcd8c7b 100644 --- a/codex-rs/shell-escalation/Cargo.toml +++ b/codex-rs/shell-escalation/Cargo.toml @@ -12,6 +12,7 @@ path = "src/bin/main_execve_wrapper.rs" anyhow = { workspace = true } async-trait = { workspace = true } clap = { workspace = true, features = ["derive"] } +codex-protocol = { workspace = true } codex-utils-absolute-path = { workspace = true } libc = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/codex-rs/shell-escalation/src/lib.rs b/codex-rs/shell-escalation/src/lib.rs index 48bc1165811..1cc42a46db4 100644 --- a/codex-rs/shell-escalation/src/lib.rs +++ b/codex-rs/shell-escalation/src/lib.rs @@ -6,12 +6,22 @@ pub use unix::EscalateAction; #[cfg(unix)] pub use unix::EscalateServer; 
#[cfg(unix)] +pub use unix::EscalationDecision; +#[cfg(unix)] +pub use unix::EscalationExecution; +#[cfg(unix)] +pub use unix::EscalationPermissions; +#[cfg(unix)] pub use unix::EscalationPolicy; #[cfg(unix)] pub use unix::ExecParams; #[cfg(unix)] pub use unix::ExecResult; #[cfg(unix)] +pub use unix::Permissions; +#[cfg(unix)] +pub use unix::PreparedExec; +#[cfg(unix)] pub use unix::ShellCommandExecutor; #[cfg(unix)] pub use unix::Stopwatch; diff --git a/codex-rs/shell-escalation/src/unix/escalate_protocol.rs b/codex-rs/shell-escalation/src/unix/escalate_protocol.rs index 6002fdc57b1..ea38c7b2a6b 100644 --- a/codex-rs/shell-escalation/src/unix/escalate_protocol.rs +++ b/codex-rs/shell-escalation/src/unix/escalate_protocol.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::os::fd::RawFd; use std::path::PathBuf; +use codex_protocol::approvals::EscalationPermissions; use codex_utils_absolute_path::AbsolutePathBuf; use serde::Deserialize; use serde::Serialize; @@ -35,6 +36,38 @@ pub struct EscalateResponse { pub action: EscalateAction, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum EscalationDecision { + Run, + Escalate(EscalationExecution), + Deny { reason: Option }, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum EscalationExecution { + /// Rerun the intercepted command outside any sandbox wrapper. + Unsandboxed, + /// Rerun using the turn's current sandbox configuration. + TurnDefault, + /// Rerun using an explicit sandbox configuration attached to the request. + Permissions(EscalationPermissions), +} + +impl EscalationDecision { + pub fn run() -> Self { + Self::Run + } + + pub fn escalate(execution: EscalationExecution) -> Self { + Self::Escalate(execution) + } + + pub fn deny(reason: Option) -> Self { + Self::Deny { reason } + } +} + #[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] pub enum EscalateAction { /// The command should be run directly by the client. 
diff --git a/codex-rs/shell-escalation/src/unix/escalate_server.rs b/codex-rs/shell-escalation/src/unix/escalate_server.rs index 2951f584a6a..bb45c3ed718 100644 --- a/codex-rs/shell-escalation/src/unix/escalate_server.rs +++ b/codex-rs/shell-escalation/src/unix/escalate_server.rs @@ -15,6 +15,8 @@ use crate::unix::escalate_protocol::EXEC_WRAPPER_ENV_VAR; use crate::unix::escalate_protocol::EscalateAction; use crate::unix::escalate_protocol::EscalateRequest; use crate::unix::escalate_protocol::EscalateResponse; +use crate::unix::escalate_protocol::EscalationDecision; +use crate::unix::escalate_protocol::EscalationExecution; use crate::unix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR; use crate::unix::escalate_protocol::SuperExecMessage; use crate::unix::escalate_protocol::SuperExecResult; @@ -37,6 +39,16 @@ pub trait ShellCommandExecutor: Send + Sync { env: HashMap, cancel_rx: CancellationToken, ) -> anyhow::Result; + + /// Prepares an escalated subcommand for execution on the server side. + async fn prepare_escalated_exec( + &self, + program: &AbsolutePathBuf, + argv: &[String], + workdir: &AbsolutePathBuf, + env: HashMap, + execution: EscalationExecution, + ) -> anyhow::Result; } #[derive(Debug, serde::Deserialize, serde::Serialize)] @@ -62,6 +74,14 @@ pub struct ExecResult { pub timed_out: bool, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PreparedExec { + pub command: Vec, + pub cwd: PathBuf, + pub env: HashMap, + pub arg0: Option, +} + pub struct EscalateServer { bash_path: PathBuf, execve_wrapper: PathBuf, @@ -69,9 +89,9 @@ pub struct EscalateServer { } impl EscalateServer { - pub fn new

(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self + pub fn new(bash_path: PathBuf, execve_wrapper: PathBuf, policy: Policy) -> Self where - P: EscalationPolicy + Send + Sync + 'static, + Policy: EscalationPolicy + Send + Sync + 'static, { Self { bash_path, @@ -84,13 +104,17 @@ impl EscalateServer { &self, params: ExecParams, cancel_rx: CancellationToken, - command_executor: &dyn ShellCommandExecutor, + command_executor: Arc, ) -> anyhow::Result { let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?; let client_socket = escalate_client.into_inner(); // Only the client endpoint should cross exec into the wrapper process. client_socket.set_cloexec(false)?; - let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone())); + let escalate_task = tokio::spawn(escalate_task( + escalate_server, + Arc::clone(&self.policy), + Arc::clone(&command_executor), + )); let mut env = std::env::vars().collect::>(); env.insert( ESCALATE_SOCKET_ENV_VAR.to_string(), @@ -126,6 +150,7 @@ impl EscalateServer { async fn escalate_task( socket: AsyncDatagramSocket, policy: Arc, + command_executor: Arc, ) -> anyhow::Result<()> { loop { let (_, mut fds) = socket.receive_with_fds().await?; @@ -134,9 +159,12 @@ async fn escalate_task( continue; } let stream_socket = AsyncSocket::from_fd(fds.remove(0))?; - let policy = policy.clone(); + let policy = Arc::clone(&policy); + let command_executor = Arc::clone(&command_executor); tokio::spawn(async move { - if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await { + if let Err(err) = + handle_escalate_session_with_policy(stream_socket, policy, command_executor).await + { tracing::error!("escalate session failed: {err:?}"); } }); @@ -146,6 +174,7 @@ async fn escalate_task( async fn handle_escalate_session_with_policy( socket: AsyncSocket, policy: Arc, + command_executor: Arc, ) -> anyhow::Result<()> { let EscalateRequest { file, @@ -154,22 +183,22 @@ async fn 
handle_escalate_session_with_policy( env, } = socket.receive::().await?; let program = AbsolutePathBuf::resolve_path_against_base(file, workdir.as_path())?; - let action = policy + let decision = policy .determine_action(&program, &argv, &workdir) .await .context("failed to determine escalation action")?; - tracing::debug!("decided {action:?} for {program:?} {argv:?} {workdir:?}"); + tracing::debug!("decided {decision:?} for {program:?} {argv:?} {workdir:?}"); - match action { - EscalateAction::Run => { + match decision { + EscalationDecision::Run => { socket .send(EscalateResponse { action: EscalateAction::Run, }) .await?; } - EscalateAction::Escalate => { + EscalationDecision::Escalate(execution) => { socket .send(EscalateResponse { action: EscalateAction::Escalate, @@ -197,12 +226,23 @@ async fn handle_escalate_session_with_policy( )); } - let mut command = Command::new(program.as_path()); + let PreparedExec { + command, + cwd, + env, + arg0, + } = command_executor + .prepare_escalated_exec(&program, &argv, &workdir, env, execution) + .await?; + let (program, args) = command + .split_first() + .ok_or_else(|| anyhow::anyhow!("prepared escalated command must not be empty"))?; + let mut command = Command::new(program); command - .args(&argv[1..]) - .arg0(argv[0].clone()) + .args(args) + .arg0(arg0.unwrap_or_else(|| program.clone())) .envs(&env) - .current_dir(&workdir) + .current_dir(&cwd) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()); @@ -222,7 +262,7 @@ async fn handle_escalate_session_with_policy( }) .await?; } - EscalateAction::Deny { reason } => { + EscalationDecision::Deny { reason } => { socket .send(EscalateResponse { action: EscalateAction::Deny { reason }, @@ -236,13 +276,15 @@ async fn handle_escalate_session_with_policy( #[cfg(test)] mod tests { use super::*; + use codex_protocol::approvals::EscalationPermissions; + use codex_protocol::models::PermissionProfile; use codex_utils_absolute_path::AbsolutePathBuf; use 
pretty_assertions::assert_eq; use std::collections::HashMap; use std::path::PathBuf; struct DeterministicEscalationPolicy { - action: EscalateAction, + decision: EscalationDecision, } #[async_trait::async_trait] @@ -252,8 +294,8 @@ mod tests { _file: &AbsolutePathBuf, _argv: &[String], _workdir: &AbsolutePathBuf, - ) -> anyhow::Result { - Ok(self.action.clone()) + ) -> anyhow::Result { + Ok(self.decision.clone()) } } @@ -269,10 +311,82 @@ mod tests { file: &AbsolutePathBuf, _argv: &[String], workdir: &AbsolutePathBuf, - ) -> anyhow::Result { + ) -> anyhow::Result { assert_eq!(file, &self.expected_file); assert_eq!(workdir, &self.expected_workdir); - Ok(EscalateAction::Run) + Ok(EscalationDecision::run()) + } + } + + struct ForwardingShellCommandExecutor; + + #[async_trait::async_trait] + impl ShellCommandExecutor for ForwardingShellCommandExecutor { + async fn run( + &self, + _command: Vec, + _cwd: PathBuf, + _env: HashMap, + _cancel_rx: CancellationToken, + ) -> anyhow::Result { + unreachable!("run() is not used by handle_escalate_session_with_policy() tests") + } + + async fn prepare_escalated_exec( + &self, + program: &AbsolutePathBuf, + argv: &[String], + workdir: &AbsolutePathBuf, + env: HashMap, + _execution: EscalationExecution, + ) -> anyhow::Result { + Ok(PreparedExec { + command: std::iter::once(program.to_string_lossy().to_string()) + .chain(argv.iter().skip(1).cloned()) + .collect(), + cwd: workdir.to_path_buf(), + env, + arg0: argv.first().cloned(), + }) + } + } + + struct PermissionAssertingShellCommandExecutor { + expected_permissions: EscalationPermissions, + } + + #[async_trait::async_trait] + impl ShellCommandExecutor for PermissionAssertingShellCommandExecutor { + async fn run( + &self, + _command: Vec, + _cwd: PathBuf, + _env: HashMap, + _cancel_rx: CancellationToken, + ) -> anyhow::Result { + unreachable!("run() is not used by handle_escalate_session_with_policy() tests") + } + + async fn prepare_escalated_exec( + &self, + program: 
&AbsolutePathBuf, + argv: &[String], + workdir: &AbsolutePathBuf, + env: HashMap, + execution: EscalationExecution, + ) -> anyhow::Result { + assert_eq!( + execution, + EscalationExecution::Permissions(self.expected_permissions.clone()) + ); + Ok(PreparedExec { + command: std::iter::once(program.to_string_lossy().to_string()) + .chain(argv.iter().skip(1).cloned()) + .collect(), + cwd: workdir.to_path_buf(), + env, + arg0: argv.first().cloned(), + }) } } @@ -282,8 +396,9 @@ mod tests { let server_task = tokio::spawn(handle_escalate_session_with_policy( server, Arc::new(DeterministicEscalationPolicy { - action: EscalateAction::Run, + decision: EscalationDecision::run(), }), + Arc::new(ForwardingShellCommandExecutor), )); let mut env = HashMap::new(); @@ -326,6 +441,7 @@ mod tests { expected_file, expected_workdir: workdir.clone(), }), + Arc::new(ForwardingShellCommandExecutor), )); client @@ -353,8 +469,9 @@ mod tests { let server_task = tokio::spawn(handle_escalate_session_with_policy( server, Arc::new(DeterministicEscalationPolicy { - action: EscalateAction::Escalate, + decision: EscalationDecision::escalate(EscalationExecution::Unsandboxed), }), + Arc::new(ForwardingShellCommandExecutor), )); client @@ -387,4 +504,52 @@ mod tests { server_task.await? 
} + + #[tokio::test] + async fn handle_escalate_session_passes_permissions_to_executor() -> anyhow::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let server_task = tokio::spawn(handle_escalate_session_with_policy( + server, + Arc::new(DeterministicEscalationPolicy { + decision: EscalationDecision::escalate(EscalationExecution::Permissions( + EscalationPermissions::PermissionProfile(PermissionProfile { + network: Some(true), + ..Default::default() + }), + )), + }), + Arc::new(PermissionAssertingShellCommandExecutor { + expected_permissions: EscalationPermissions::PermissionProfile(PermissionProfile { + network: Some(true), + ..Default::default() + }), + }), + )); + + client + .send(EscalateRequest { + file: PathBuf::from("/bin/sh"), + argv: vec!["sh".to_string(), "-c".to_string(), "exit 0".to_string()], + workdir: AbsolutePathBuf::current_dir()?, + env: HashMap::new(), + }) + .await?; + + let response = client.receive::().await?; + assert_eq!( + EscalateResponse { + action: EscalateAction::Escalate, + }, + response + ); + + client + .send_with_fds(SuperExecMessage { fds: Vec::new() }, &[]) + .await?; + + let result = client.receive::().await?; + assert_eq!(0, result.exit_code); + + server_task.await? + } } diff --git a/codex-rs/shell-escalation/src/unix/escalation_policy.rs b/codex-rs/shell-escalation/src/unix/escalation_policy.rs index 28e04e6f590..c3da252dd9f 100644 --- a/codex-rs/shell-escalation/src/unix/escalation_policy.rs +++ b/codex-rs/shell-escalation/src/unix/escalation_policy.rs @@ -1,6 +1,6 @@ use codex_utils_absolute_path::AbsolutePathBuf; -use crate::unix::escalate_protocol::EscalateAction; +use crate::unix::escalate_protocol::EscalationDecision; /// Decides what action to take in response to an execve request from a client. 
#[async_trait::async_trait] @@ -10,5 +10,5 @@ pub trait EscalationPolicy: Send + Sync { file: &AbsolutePathBuf, argv: &[String], workdir: &AbsolutePathBuf, - ) -> anyhow::Result; + ) -> anyhow::Result; } diff --git a/codex-rs/shell-escalation/src/unix/mod.rs b/codex-rs/shell-escalation/src/unix/mod.rs index 37e29e87754..6de12297a46 100644 --- a/codex-rs/shell-escalation/src/unix/mod.rs +++ b/codex-rs/shell-escalation/src/unix/mod.rs @@ -63,10 +63,15 @@ pub mod stopwatch; pub use self::escalate_client::run_shell_escalation_execve_wrapper; pub use self::escalate_protocol::EscalateAction; +pub use self::escalate_protocol::EscalationDecision; +pub use self::escalate_protocol::EscalationExecution; pub use self::escalate_server::EscalateServer; pub use self::escalate_server::ExecParams; pub use self::escalate_server::ExecResult; +pub use self::escalate_server::PreparedExec; pub use self::escalate_server::ShellCommandExecutor; pub use self::escalation_policy::EscalationPolicy; pub use self::execve_wrapper::main_execve_wrapper; pub use self::stopwatch::Stopwatch; +pub use codex_protocol::approvals::EscalationPermissions; +pub use codex_protocol::approvals::Permissions; diff --git a/codex-rs/state/src/extract.rs b/codex-rs/state/src/extract.rs index 0ba3df6790b..f9782389790 100644 --- a/codex-rs/state/src/extract.rs +++ b/codex-rs/state/src/extract.rs @@ -56,7 +56,9 @@ fn apply_session_meta_from_item(metadata: &mut ThreadMetadata, meta_line: &Sessi } fn apply_turn_context(metadata: &mut ThreadMetadata, turn_ctx: &TurnContextItem) { - metadata.cwd = turn_ctx.cwd.clone(); + if metadata.cwd.as_os_str().is_empty() { + metadata.cwd = turn_ctx.cwd.clone(); + } metadata.sandbox_policy = enum_to_string(&turn_ctx.sandbox_policy); metadata.approval_mode = enum_to_string(&turn_ctx.approval_policy); } @@ -125,10 +127,17 @@ mod tests { use chrono::DateTime; use chrono::Utc; use codex_protocol::ThreadId; + use codex_protocol::config_types::ReasoningSummary; use 
codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; + use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::RolloutItem; + use codex_protocol::protocol::SandboxPolicy; + use codex_protocol::protocol::SessionMeta; + use codex_protocol::protocol::SessionMetaLine; + use codex_protocol::protocol::SessionSource; + use codex_protocol::protocol::TurnContextItem; use codex_protocol::protocol::USER_MESSAGE_BEGIN; use codex_protocol::protocol::UserMessageEvent; @@ -209,6 +218,97 @@ mod tests { assert_eq!(metadata.title, ""); } + #[test] + fn turn_context_does_not_override_session_cwd() { + let mut metadata = metadata_for_test(); + metadata.cwd = PathBuf::new(); + let thread_id = metadata.id; + + apply_rollout_item( + &mut metadata, + &RolloutItem::SessionMeta(SessionMetaLine { + meta: SessionMeta { + id: thread_id, + forked_from_id: Some( + ThreadId::from_string(&Uuid::now_v7().to_string()).expect("thread id"), + ), + timestamp: "2026-02-26T00:00:00.000Z".to_string(), + cwd: PathBuf::from("/child/worktree"), + originator: "codex_cli_rs".to_string(), + cli_version: "0.0.0".to_string(), + source: SessionSource::Cli, + agent_nickname: None, + agent_role: None, + model_provider: Some("openai".to_string()), + base_instructions: None, + dynamic_tools: None, + }, + git: None, + }), + "test-provider", + ); + apply_rollout_item( + &mut metadata, + &RolloutItem::TurnContext(TurnContextItem { + turn_id: Some("turn-1".to_string()), + cwd: PathBuf::from("/parent/workspace"), + current_date: None, + timezone: None, + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + network: None, + model: "gpt-5".to_string(), + personality: None, + collaboration_mode: None, + effort: None, + summary: ReasoningSummary::Auto, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: None, + }), + "test-provider", + ); + + 
assert_eq!(metadata.cwd, PathBuf::from("/child/worktree")); + assert_eq!( + metadata.sandbox_policy, + super::enum_to_string(&SandboxPolicy::DangerFullAccess) + ); + assert_eq!(metadata.approval_mode, "never"); + } + + #[test] + fn turn_context_sets_cwd_when_session_cwd_missing() { + let mut metadata = metadata_for_test(); + metadata.cwd = PathBuf::new(); + + apply_rollout_item( + &mut metadata, + &RolloutItem::TurnContext(TurnContextItem { + turn_id: Some("turn-1".to_string()), + cwd: PathBuf::from("/fallback/workspace"), + current_date: None, + timezone: None, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + network: None, + model: "gpt-5".to_string(), + personality: None, + collaboration_mode: None, + effort: None, + summary: ReasoningSummary::Auto, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: None, + }), + "test-provider", + ); + + assert_eq!(metadata.cwd, PathBuf::from("/fallback/workspace")); + } + fn metadata_for_test() -> ThreadMetadata { let id = ThreadId::from_string(&Uuid::from_u128(42).to_string()).expect("thread id"); let created_at = DateTime::::from_timestamp(1_735_689_600, 0).expect("timestamp"); diff --git a/codex-rs/state/src/model/memories.rs b/codex-rs/state/src/model/memories.rs index 6c88d7360e4..0e663bf9048 100644 --- a/codex-rs/state/src/model/memories.rs +++ b/codex-rs/state/src/model/memories.rs @@ -18,6 +18,7 @@ pub struct Stage1Output { pub rollout_summary: String, pub rollout_slug: Option, pub cwd: PathBuf, + pub git_branch: Option, pub generated_at: DateTime, } @@ -45,6 +46,7 @@ pub(crate) struct Stage1OutputRow { rollout_summary: String, rollout_slug: Option, cwd: String, + git_branch: Option, generated_at: i64, } @@ -58,6 +60,7 @@ impl Stage1OutputRow { rollout_summary: row.try_get("rollout_summary")?, rollout_slug: row.try_get("rollout_slug")?, cwd: row.try_get("cwd")?, + git_branch: 
row.try_get("git_branch")?, generated_at: row.try_get("generated_at")?, }) } @@ -75,6 +78,7 @@ impl TryFrom for Stage1Output { rollout_summary: row.rollout_summary, rollout_slug: row.rollout_slug, cwd: PathBuf::from(row.cwd), + git_branch: row.git_branch, generated_at: epoch_seconds_to_datetime(row.generated_at)?, }) } diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index b7ee2ac883a..42d235fc765 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -18,7 +18,6 @@ use crate::ThreadMetadataBuilder; use crate::ThreadsPage; use crate::apply_rollout_item; use crate::migrations::MIGRATOR; -use crate::model::AgentJobItemRow; use crate::model::AgentJobRow; use crate::model::ThreadRow; use crate::model::anchor_from_item; @@ -48,10 +47,14 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use tracing::warn; -use uuid::Uuid; +mod agent_jobs; +mod backfill; +mod logs; mod memories; -// Memory-specific CRUD and phase job lifecycle methods live in `runtime/memories.rs`. +#[cfg(test)] +mod test_support; +mod threads; // "Partition" is the retention bucket we cap at 10 MiB: // - one bucket per non-null thread_id @@ -107,1461 +110,6 @@ impl StateRuntime { pub fn codex_home(&self) -> &Path { self.codex_home.as_path() } - - /// Get persisted rollout metadata backfill state. - pub async fn get_backfill_state(&self) -> anyhow::Result { - self.ensure_backfill_state_row().await?; - let row = sqlx::query( - r#" -SELECT status, last_watermark, last_success_at -FROM backfill_state -WHERE id = 1 - "#, - ) - .fetch_one(self.pool.as_ref()) - .await?; - crate::BackfillState::try_from_row(&row) - } - - /// Attempt to claim ownership of rollout metadata backfill. - /// - /// Returns `true` when this runtime claimed the backfill worker slot. - /// Returns `false` if backfill is already complete or currently owned by a - /// non-expired worker. 
- pub async fn try_claim_backfill(&self, lease_seconds: i64) -> anyhow::Result { - self.ensure_backfill_state_row().await?; - let now = Utc::now().timestamp(); - let lease_cutoff = now.saturating_sub(lease_seconds.max(0)); - let result = sqlx::query( - r#" -UPDATE backfill_state -SET status = ?, updated_at = ? -WHERE id = 1 - AND status != ? - AND (status != ? OR updated_at <= ?) - "#, - ) - .bind(crate::BackfillStatus::Running.as_str()) - .bind(now) - .bind(crate::BackfillStatus::Complete.as_str()) - .bind(crate::BackfillStatus::Running.as_str()) - .bind(lease_cutoff) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() == 1) - } - - /// Mark rollout metadata backfill as running. - pub async fn mark_backfill_running(&self) -> anyhow::Result<()> { - self.ensure_backfill_state_row().await?; - sqlx::query( - r#" -UPDATE backfill_state -SET status = ?, updated_at = ? -WHERE id = 1 - "#, - ) - .bind(crate::BackfillStatus::Running.as_str()) - .bind(Utc::now().timestamp()) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - /// Persist rollout metadata backfill progress. - pub async fn checkpoint_backfill(&self, watermark: &str) -> anyhow::Result<()> { - self.ensure_backfill_state_row().await?; - sqlx::query( - r#" -UPDATE backfill_state -SET status = ?, last_watermark = ?, updated_at = ? -WHERE id = 1 - "#, - ) - .bind(crate::BackfillStatus::Running.as_str()) - .bind(watermark) - .bind(Utc::now().timestamp()) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - /// Mark rollout metadata backfill as complete. - pub async fn mark_backfill_complete(&self, last_watermark: Option<&str>) -> anyhow::Result<()> { - self.ensure_backfill_state_row().await?; - let now = Utc::now().timestamp(); - sqlx::query( - r#" -UPDATE backfill_state -SET - status = ?, - last_watermark = COALESCE(?, last_watermark), - last_success_at = ?, - updated_at = ? 
-WHERE id = 1 - "#, - ) - .bind(crate::BackfillStatus::Complete.as_str()) - .bind(last_watermark) - .bind(now) - .bind(now) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - /// Load thread metadata by id using the underlying database. - pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result> { - let row = sqlx::query( - r#" -SELECT - id, - rollout_path, - created_at, - updated_at, - source, - agent_nickname, - agent_role, - model_provider, - cwd, - cli_version, - title, - sandbox_policy, - approval_mode, - tokens_used, - first_user_message, - archived_at, - git_sha, - git_branch, - git_origin_url -FROM threads -WHERE id = ? - "#, - ) - .bind(id.to_string()) - .fetch_optional(self.pool.as_ref()) - .await?; - row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from)) - .transpose() - } - - /// Get dynamic tools for a thread, if present. - pub async fn get_dynamic_tools( - &self, - thread_id: ThreadId, - ) -> anyhow::Result>> { - let rows = sqlx::query( - r#" -SELECT name, description, input_schema -FROM thread_dynamic_tools -WHERE thread_id = ? -ORDER BY position ASC - "#, - ) - .bind(thread_id.to_string()) - .fetch_all(self.pool.as_ref()) - .await?; - if rows.is_empty() { - return Ok(None); - } - let mut tools = Vec::with_capacity(rows.len()); - for row in rows { - let input_schema: String = row.try_get("input_schema")?; - let input_schema = serde_json::from_str::(input_schema.as_str())?; - tools.push(DynamicToolSpec { - name: row.try_get("name")?, - description: row.try_get("description")?, - input_schema, - }); - } - Ok(Some(tools)) - } - - /// Find a rollout path by thread id using the underlying database. 
- pub async fn find_rollout_path_by_id( - &self, - id: ThreadId, - archived_only: Option, - ) -> anyhow::Result> { - let mut builder = - QueryBuilder::::new("SELECT rollout_path FROM threads WHERE id = "); - builder.push_bind(id.to_string()); - match archived_only { - Some(true) => { - builder.push(" AND archived = 1"); - } - Some(false) => { - builder.push(" AND archived = 0"); - } - None => {} - } - let row = builder.build().fetch_optional(self.pool.as_ref()).await?; - Ok(row - .and_then(|r| r.try_get::("rollout_path").ok()) - .map(PathBuf::from)) - } - - /// List threads using the underlying database. - #[allow(clippy::too_many_arguments)] - pub async fn list_threads( - &self, - page_size: usize, - anchor: Option<&crate::Anchor>, - sort_key: crate::SortKey, - allowed_sources: &[String], - model_providers: Option<&[String]>, - archived_only: bool, - search_term: Option<&str>, - ) -> anyhow::Result { - let limit = page_size.saturating_add(1); - - let mut builder = QueryBuilder::::new( - r#" -SELECT - id, - rollout_path, - created_at, - updated_at, - source, - agent_nickname, - agent_role, - model_provider, - cwd, - cli_version, - title, - sandbox_policy, - approval_mode, - tokens_used, - first_user_message, - archived_at, - git_sha, - git_branch, - git_origin_url -FROM threads - "#, - ); - push_thread_filters( - &mut builder, - archived_only, - allowed_sources, - model_providers, - anchor, - sort_key, - search_term, - ); - push_thread_order_and_limit(&mut builder, sort_key, limit); - - let rows = builder.build().fetch_all(self.pool.as_ref()).await?; - let mut items = rows - .into_iter() - .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from)) - .collect::, _>>()?; - let num_scanned_rows = items.len(); - let next_anchor = if items.len() > page_size { - items.pop(); - items - .last() - .and_then(|item| anchor_from_item(item, sort_key)) - } else { - None - }; - Ok(ThreadsPage { - items, - next_anchor, - num_scanned_rows, - }) - } - - /// Insert 
one log entry into the logs table. - pub async fn insert_log(&self, entry: &LogEntry) -> anyhow::Result<()> { - self.insert_logs(std::slice::from_ref(entry)).await - } - - /// Insert a batch of log entries into the logs table. - pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> { - if entries.is_empty() { - return Ok(()); - } - - let mut tx = self.pool.begin().await?; - let mut builder = QueryBuilder::::new( - "INSERT INTO logs (ts, ts_nanos, level, target, message, thread_id, process_uuid, module_path, file, line, estimated_bytes) ", - ); - builder.push_values(entries, |mut row, entry| { - let estimated_bytes = entry.message.as_ref().map_or(0, String::len) as i64 - + entry.level.len() as i64 - + entry.target.len() as i64 - + entry.module_path.as_ref().map_or(0, String::len) as i64 - + entry.file.as_ref().map_or(0, String::len) as i64; - row.push_bind(entry.ts) - .push_bind(entry.ts_nanos) - .push_bind(&entry.level) - .push_bind(&entry.target) - .push_bind(&entry.message) - .push_bind(&entry.thread_id) - .push_bind(&entry.process_uuid) - .push_bind(&entry.module_path) - .push_bind(&entry.file) - .push_bind(entry.line) - .push_bind(estimated_bytes); - }); - builder.build().execute(&mut *tx).await?; - self.prune_logs_after_insert(entries, &mut tx).await?; - tx.commit().await?; - Ok(()) - } - - /// Enforce per-partition log size caps after a successful batch insert. - /// - /// We maintain two independent budgets: - /// - Thread logs: rows with `thread_id IS NOT NULL`, capped per `thread_id`. - /// - Threadless process logs: rows with `thread_id IS NULL` ("threadless"), - /// capped per `process_uuid` (including `process_uuid IS NULL` as its own - /// threadless partition). - /// - /// "Threadless" means the log row is not associated with any conversation - /// thread, so retention is keyed by process identity instead. 
- /// - /// This runs inside the same transaction as the insert so callers never - /// observe "inserted but not yet pruned" rows. - async fn prune_logs_after_insert( - &self, - entries: &[LogEntry], - tx: &mut SqliteConnection, - ) -> anyhow::Result<()> { - let thread_ids: BTreeSet<&str> = entries - .iter() - .filter_map(|entry| entry.thread_id.as_deref()) - .collect(); - if !thread_ids.is_empty() { - // Cheap precheck: only run the heavier window-function prune for - // threads that are currently above the cap. - let mut over_limit_threads_query = - QueryBuilder::::new("SELECT thread_id FROM logs WHERE thread_id IN ("); - { - let mut separated = over_limit_threads_query.separated(", "); - for thread_id in &thread_ids { - separated.push_bind(*thread_id); - } - } - over_limit_threads_query.push(") GROUP BY thread_id HAVING SUM("); - over_limit_threads_query.push("estimated_bytes"); - over_limit_threads_query.push(") > "); - over_limit_threads_query.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); - let over_limit_thread_ids: Vec = over_limit_threads_query - .build() - .fetch_all(&mut *tx) - .await? - .into_iter() - .map(|row| row.try_get("thread_id")) - .collect::>()?; - if !over_limit_thread_ids.is_empty() { - // Enforce a strict per-thread cap by deleting every row whose - // newest-first cumulative bytes exceed the partition budget. 
- let mut prune_threads = QueryBuilder::::new( - r#" -DELETE FROM logs -WHERE id IN ( - SELECT id - FROM ( - SELECT - id, - SUM( -"#, - ); - prune_threads.push("estimated_bytes"); - prune_threads.push( - r#" - ) OVER ( - PARTITION BY thread_id - ORDER BY ts DESC, ts_nanos DESC, id DESC - ) AS cumulative_bytes - FROM logs - WHERE thread_id IN ( -"#, - ); - { - let mut separated = prune_threads.separated(", "); - for thread_id in &over_limit_thread_ids { - separated.push_bind(thread_id); - } - } - prune_threads.push( - r#" - ) - ) - WHERE cumulative_bytes > -"#, - ); - prune_threads.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); - prune_threads.push("\n)"); - prune_threads.build().execute(&mut *tx).await?; - } - } - - let threadless_process_uuids: BTreeSet<&str> = entries - .iter() - .filter(|entry| entry.thread_id.is_none()) - .filter_map(|entry| entry.process_uuid.as_deref()) - .collect(); - let has_threadless_null_process_uuid = entries - .iter() - .any(|entry| entry.thread_id.is_none() && entry.process_uuid.is_none()); - if !threadless_process_uuids.is_empty() { - // Threadless logs are budgeted separately per process UUID. - let mut over_limit_processes_query = QueryBuilder::::new( - "SELECT process_uuid FROM logs WHERE thread_id IS NULL AND process_uuid IN (", - ); - { - let mut separated = over_limit_processes_query.separated(", "); - for process_uuid in &threadless_process_uuids { - separated.push_bind(*process_uuid); - } - } - over_limit_processes_query.push(") GROUP BY process_uuid HAVING SUM("); - over_limit_processes_query.push("estimated_bytes"); - over_limit_processes_query.push(") > "); - over_limit_processes_query.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); - let over_limit_process_uuids: Vec = over_limit_processes_query - .build() - .fetch_all(&mut *tx) - .await? 
- .into_iter() - .map(|row| row.try_get("process_uuid")) - .collect::>()?; - if !over_limit_process_uuids.is_empty() { - // Same strict cap policy as thread pruning, but only for - // threadless rows in the affected process UUIDs. - let mut prune_threadless_process_logs = QueryBuilder::::new( - r#" -DELETE FROM logs -WHERE id IN ( - SELECT id - FROM ( - SELECT - id, - SUM( -"#, - ); - prune_threadless_process_logs.push("estimated_bytes"); - prune_threadless_process_logs.push( - r#" - ) OVER ( - PARTITION BY process_uuid - ORDER BY ts DESC, ts_nanos DESC, id DESC - ) AS cumulative_bytes - FROM logs - WHERE thread_id IS NULL - AND process_uuid IN ( -"#, - ); - { - let mut separated = prune_threadless_process_logs.separated(", "); - for process_uuid in &over_limit_process_uuids { - separated.push_bind(process_uuid); - } - } - prune_threadless_process_logs.push( - r#" - ) - ) - WHERE cumulative_bytes > -"#, - ); - prune_threadless_process_logs.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); - prune_threadless_process_logs.push("\n)"); - prune_threadless_process_logs - .build() - .execute(&mut *tx) - .await?; - } - } - if has_threadless_null_process_uuid { - // Rows without a process UUID still need a cap; treat NULL as its - // own threadless partition. - let mut null_process_usage_query = QueryBuilder::::new("SELECT SUM("); - null_process_usage_query.push("estimated_bytes"); - null_process_usage_query.push( - ") AS total_bytes FROM logs WHERE thread_id IS NULL AND process_uuid IS NULL", - ); - let total_null_process_bytes: Option = null_process_usage_query - .build() - .fetch_one(&mut *tx) - .await? 
- .try_get("total_bytes")?; - - if total_null_process_bytes.unwrap_or(0) > LOG_PARTITION_SIZE_LIMIT_BYTES { - let mut prune_threadless_null_process_logs = QueryBuilder::::new( - r#" -DELETE FROM logs -WHERE id IN ( - SELECT id - FROM ( - SELECT - id, - SUM( -"#, - ); - prune_threadless_null_process_logs.push("estimated_bytes"); - prune_threadless_null_process_logs.push( - r#" - ) OVER ( - PARTITION BY process_uuid - ORDER BY ts DESC, ts_nanos DESC, id DESC - ) AS cumulative_bytes - FROM logs - WHERE thread_id IS NULL - AND process_uuid IS NULL - ) - WHERE cumulative_bytes > -"#, - ); - prune_threadless_null_process_logs.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); - prune_threadless_null_process_logs.push("\n)"); - prune_threadless_null_process_logs - .build() - .execute(&mut *tx) - .await?; - } - } - Ok(()) - } - - pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result { - let result = sqlx::query("DELETE FROM logs WHERE ts < ?") - .bind(cutoff_ts) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected()) - } - - /// Query logs with optional filters. - pub async fn query_logs(&self, query: &LogQuery) -> anyhow::Result> { - let mut builder = QueryBuilder::::new( - "SELECT id, ts, ts_nanos, level, target, message, thread_id, process_uuid, file, line FROM logs WHERE 1 = 1", - ); - push_log_filters(&mut builder, query); - if query.descending { - builder.push(" ORDER BY id DESC"); - } else { - builder.push(" ORDER BY id ASC"); - } - if let Some(limit) = query.limit { - builder.push(" LIMIT ").push_bind(limit as i64); - } - - let rows = builder - .build_query_as::() - .fetch_all(self.pool.as_ref()) - .await?; - Ok(rows) - } - - /// Return the max log id matching optional filters. 
- pub async fn max_log_id(&self, query: &LogQuery) -> anyhow::Result { - let mut builder = - QueryBuilder::::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1"); - push_log_filters(&mut builder, query); - let row = builder.build().fetch_one(self.pool.as_ref()).await?; - let max_id: Option = row.try_get("max_id")?; - Ok(max_id.unwrap_or(0)) - } - - /// List thread ids using the underlying database (no rollout scanning). - pub async fn list_thread_ids( - &self, - limit: usize, - anchor: Option<&crate::Anchor>, - sort_key: crate::SortKey, - allowed_sources: &[String], - model_providers: Option<&[String]>, - archived_only: bool, - ) -> anyhow::Result> { - let mut builder = QueryBuilder::::new("SELECT id FROM threads"); - push_thread_filters( - &mut builder, - archived_only, - allowed_sources, - model_providers, - anchor, - sort_key, - None, - ); - push_thread_order_and_limit(&mut builder, sort_key, limit); - - let rows = builder.build().fetch_all(self.pool.as_ref()).await?; - rows.into_iter() - .map(|row| { - let id: String = row.try_get("id")?; - Ok(ThreadId::try_from(id)?) - }) - .collect() - } - - /// Insert or replace thread metadata directly. - pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> { - sqlx::query( - r#" -INSERT INTO threads ( - id, - rollout_path, - created_at, - updated_at, - source, - agent_nickname, - agent_role, - model_provider, - cwd, - cli_version, - title, - sandbox_policy, - approval_mode, - tokens_used, - first_user_message, - archived, - archived_at, - git_sha, - git_branch, - git_origin_url -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
-ON CONFLICT(id) DO UPDATE SET - rollout_path = excluded.rollout_path, - created_at = excluded.created_at, - updated_at = excluded.updated_at, - source = excluded.source, - agent_nickname = excluded.agent_nickname, - agent_role = excluded.agent_role, - model_provider = excluded.model_provider, - cwd = excluded.cwd, - cli_version = excluded.cli_version, - title = excluded.title, - sandbox_policy = excluded.sandbox_policy, - approval_mode = excluded.approval_mode, - tokens_used = excluded.tokens_used, - first_user_message = excluded.first_user_message, - archived = excluded.archived, - archived_at = excluded.archived_at, - git_sha = excluded.git_sha, - git_branch = excluded.git_branch, - git_origin_url = excluded.git_origin_url - "#, - ) - .bind(metadata.id.to_string()) - .bind(metadata.rollout_path.display().to_string()) - .bind(datetime_to_epoch_seconds(metadata.created_at)) - .bind(datetime_to_epoch_seconds(metadata.updated_at)) - .bind(metadata.source.as_str()) - .bind(metadata.agent_nickname.as_deref()) - .bind(metadata.agent_role.as_deref()) - .bind(metadata.model_provider.as_str()) - .bind(metadata.cwd.display().to_string()) - .bind(metadata.cli_version.as_str()) - .bind(metadata.title.as_str()) - .bind(metadata.sandbox_policy.as_str()) - .bind(metadata.approval_mode.as_str()) - .bind(metadata.tokens_used) - .bind(metadata.first_user_message.as_deref().unwrap_or_default()) - .bind(metadata.archived_at.is_some()) - .bind(metadata.archived_at.map(datetime_to_epoch_seconds)) - .bind(metadata.git_sha.as_deref()) - .bind(metadata.git_branch.as_deref()) - .bind(metadata.git_origin_url.as_deref()) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - /// Persist dynamic tools for a thread if none have been stored yet. - /// - /// Dynamic tools are defined at thread start and should not change afterward. - /// This only writes the first time we see tools for a given thread. 
- pub async fn persist_dynamic_tools( - &self, - thread_id: ThreadId, - tools: Option<&[DynamicToolSpec]>, - ) -> anyhow::Result<()> { - let Some(tools) = tools else { - return Ok(()); - }; - if tools.is_empty() { - return Ok(()); - } - let thread_id = thread_id.to_string(); - let mut tx = self.pool.begin().await?; - for (idx, tool) in tools.iter().enumerate() { - let position = i64::try_from(idx).unwrap_or(i64::MAX); - let input_schema = serde_json::to_string(&tool.input_schema)?; - sqlx::query( - r#" -INSERT INTO thread_dynamic_tools ( - thread_id, - position, - name, - description, - input_schema -) VALUES (?, ?, ?, ?, ?) -ON CONFLICT(thread_id, position) DO NOTHING - "#, - ) - .bind(thread_id.as_str()) - .bind(position) - .bind(tool.name.as_str()) - .bind(tool.description.as_str()) - .bind(input_schema) - .execute(&mut *tx) - .await?; - } - tx.commit().await?; - Ok(()) - } - - /// Apply rollout items incrementally using the underlying database. - pub async fn apply_rollout_items( - &self, - builder: &ThreadMetadataBuilder, - items: &[RolloutItem], - otel: Option<&OtelManager>, - ) -> anyhow::Result<()> { - if items.is_empty() { - return Ok(()); - } - let mut metadata = self - .get_thread(builder.id) - .await? - .unwrap_or_else(|| builder.build(&self.default_provider)); - metadata.rollout_path = builder.rollout_path.clone(); - for item in items { - apply_rollout_item(&mut metadata, item, &self.default_provider); - } - if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await { - metadata.updated_at = updated_at; - } - // Keep the thread upsert before dynamic tools to satisfy the foreign key constraint: - // thread_dynamic_tools.thread_id -> threads.id. 
- if let Err(err) = self.upsert_thread(&metadata).await { - if let Some(otel) = otel { - otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]); - } - return Err(err); - } - let dynamic_tools = extract_dynamic_tools(items); - if let Some(dynamic_tools) = dynamic_tools - && let Err(err) = self - .persist_dynamic_tools(builder.id, dynamic_tools.as_deref()) - .await - { - if let Some(otel) = otel { - otel.counter(DB_ERROR_METRIC, 1, &[("stage", "persist_dynamic_tools")]); - } - return Err(err); - } - Ok(()) - } - - /// Mark a thread as archived using the underlying database. - pub async fn mark_archived( - &self, - thread_id: ThreadId, - rollout_path: &Path, - archived_at: DateTime, - ) -> anyhow::Result<()> { - let Some(mut metadata) = self.get_thread(thread_id).await? else { - return Ok(()); - }; - metadata.archived_at = Some(archived_at); - metadata.rollout_path = rollout_path.to_path_buf(); - if let Some(updated_at) = file_modified_time_utc(rollout_path).await { - metadata.updated_at = updated_at; - } - if metadata.id != thread_id { - warn!( - "thread id mismatch during archive: expected {thread_id}, got {}", - metadata.id - ); - } - self.upsert_thread(&metadata).await - } - - /// Mark a thread as unarchived using the underlying database. - pub async fn mark_unarchived( - &self, - thread_id: ThreadId, - rollout_path: &Path, - ) -> anyhow::Result<()> { - let Some(mut metadata) = self.get_thread(thread_id).await? else { - return Ok(()); - }; - metadata.archived_at = None; - metadata.rollout_path = rollout_path.to_path_buf(); - if let Some(updated_at) = file_modified_time_utc(rollout_path).await { - metadata.updated_at = updated_at; - } - if metadata.id != thread_id { - warn!( - "thread id mismatch during unarchive: expected {thread_id}, got {}", - metadata.id - ); - } - self.upsert_thread(&metadata).await - } - - /// Delete a thread metadata row by id. 
- pub async fn delete_thread(&self, thread_id: ThreadId) -> anyhow::Result { - let result = sqlx::query("DELETE FROM threads WHERE id = ?") - .bind(thread_id.to_string()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected()) - } - - pub async fn create_agent_job( - &self, - params: &AgentJobCreateParams, - items: &[AgentJobItemCreateParams], - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let input_headers_json = serde_json::to_string(¶ms.input_headers)?; - let output_schema_json = params - .output_schema_json - .as_ref() - .map(serde_json::to_string) - .transpose()?; - let max_runtime_seconds = params - .max_runtime_seconds - .map(i64::try_from) - .transpose() - .map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?; - let mut tx = self.pool.begin().await?; - sqlx::query( - r#" -INSERT INTO agent_jobs ( - id, - name, - status, - instruction, - auto_export, - max_runtime_seconds, - output_schema_json, - input_headers_json, - input_csv_path, - output_csv_path, - created_at, - updated_at, - started_at, - completed_at, - last_error -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL) - "#, - ) - .bind(params.id.as_str()) - .bind(params.name.as_str()) - .bind(AgentJobStatus::Pending.as_str()) - .bind(params.instruction.as_str()) - .bind(i64::from(params.auto_export)) - .bind(max_runtime_seconds) - .bind(output_schema_json) - .bind(input_headers_json) - .bind(params.input_csv_path.as_str()) - .bind(params.output_csv_path.as_str()) - .bind(now) - .bind(now) - .execute(&mut *tx) - .await?; - - for item in items { - let row_json = serde_json::to_string(&item.row_json)?; - sqlx::query( - r#" -INSERT INTO agent_job_items ( - job_id, - item_id, - row_index, - source_id, - row_json, - status, - assigned_thread_id, - attempt_count, - result_json, - last_error, - created_at, - updated_at, - completed_at, - reported_at -) VALUES (?, ?, ?, ?, ?, ?, NULL, 0, NULL, NULL, ?, ?, NULL, NULL) - "#, - ) - .bind(params.id.as_str()) - 
.bind(item.item_id.as_str()) - .bind(item.row_index) - .bind(item.source_id.as_deref()) - .bind(row_json) - .bind(AgentJobItemStatus::Pending.as_str()) - .bind(now) - .bind(now) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - let job_id = params.id.as_str(); - self.get_agent_job(job_id) - .await? - .ok_or_else(|| anyhow::anyhow!("failed to load created agent job {job_id}")) - } - - pub async fn get_agent_job(&self, job_id: &str) -> anyhow::Result> { - let row = sqlx::query_as::<_, AgentJobRow>( - r#" -SELECT - id, - name, - status, - instruction, - auto_export, - max_runtime_seconds, - output_schema_json, - input_headers_json, - input_csv_path, - output_csv_path, - created_at, - updated_at, - started_at, - completed_at, - last_error -FROM agent_jobs -WHERE id = ? - "#, - ) - .bind(job_id) - .fetch_optional(self.pool.as_ref()) - .await?; - row.map(AgentJob::try_from).transpose() - } - - pub async fn list_agent_job_items( - &self, - job_id: &str, - status: Option, - limit: Option, - ) -> anyhow::Result> { - let mut builder = QueryBuilder::::new( - r#" -SELECT - job_id, - item_id, - row_index, - source_id, - row_json, - status, - assigned_thread_id, - attempt_count, - result_json, - last_error, - created_at, - updated_at, - completed_at, - reported_at -FROM agent_job_items -WHERE job_id = - "#, - ); - builder.push_bind(job_id); - if let Some(status) = status { - builder.push(" AND status = "); - builder.push_bind(status.as_str()); - } - builder.push(" ORDER BY row_index ASC"); - if let Some(limit) = limit { - builder.push(" LIMIT "); - builder.push_bind(limit as i64); - } - let rows = builder - .build_query_as::() - .fetch_all(self.pool.as_ref()) - .await?; - rows.into_iter().map(AgentJobItem::try_from).collect() - } - - pub async fn get_agent_job_item( - &self, - job_id: &str, - item_id: &str, - ) -> anyhow::Result> { - let row = sqlx::query_as::<_, AgentJobItemRow>( - r#" -SELECT - job_id, - item_id, - row_index, - source_id, - row_json, - status, - 
assigned_thread_id, - attempt_count, - result_json, - last_error, - created_at, - updated_at, - completed_at, - reported_at -FROM agent_job_items -WHERE job_id = ? AND item_id = ? - "#, - ) - .bind(job_id) - .bind(item_id) - .fetch_optional(self.pool.as_ref()) - .await?; - row.map(AgentJobItem::try_from).transpose() - } - - pub async fn mark_agent_job_running(&self, job_id: &str) -> anyhow::Result<()> { - let now = Utc::now().timestamp(); - sqlx::query( - r#" -UPDATE agent_jobs -SET - status = ?, - updated_at = ?, - started_at = COALESCE(started_at, ?), - completed_at = NULL, - last_error = NULL -WHERE id = ? - "#, - ) - .bind(AgentJobStatus::Running.as_str()) - .bind(now) - .bind(now) - .bind(job_id) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<()> { - let now = Utc::now().timestamp(); - sqlx::query( - r#" -UPDATE agent_jobs -SET status = ?, updated_at = ?, completed_at = ?, last_error = NULL -WHERE id = ? - "#, - ) - .bind(AgentJobStatus::Completed.as_str()) - .bind(now) - .bind(now) - .bind(job_id) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - pub async fn mark_agent_job_failed( - &self, - job_id: &str, - error_message: &str, - ) -> anyhow::Result<()> { - let now = Utc::now().timestamp(); - sqlx::query( - r#" -UPDATE agent_jobs -SET status = ?, updated_at = ?, completed_at = ?, last_error = ? -WHERE id = ? - "#, - ) - .bind(AgentJobStatus::Failed.as_str()) - .bind(now) - .bind(now) - .bind(error_message) - .bind(job_id) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - pub async fn mark_agent_job_cancelled( - &self, - job_id: &str, - reason: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_jobs -SET status = ?, updated_at = ?, completed_at = ?, last_error = ? -WHERE id = ? AND status IN (?, ?) 
- "#, - ) - .bind(AgentJobStatus::Cancelled.as_str()) - .bind(now) - .bind(now) - .bind(reason) - .bind(job_id) - .bind(AgentJobStatus::Pending.as_str()) - .bind(AgentJobStatus::Running.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn is_agent_job_cancelled(&self, job_id: &str) -> anyhow::Result { - let row = sqlx::query( - r#" -SELECT status -FROM agent_jobs -WHERE id = ? - "#, - ) - .bind(job_id) - .fetch_optional(self.pool.as_ref()) - .await?; - let Some(row) = row else { - return Ok(false); - }; - let status: String = row.try_get("status")?; - Ok(AgentJobStatus::parse(status.as_str())? == AgentJobStatus::Cancelled) - } - - pub async fn mark_agent_job_item_running( - &self, - job_id: &str, - item_id: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET - status = ?, - assigned_thread_id = NULL, - attempt_count = attempt_count + 1, - updated_at = ?, - last_error = NULL -WHERE job_id = ? AND item_id = ? AND status = ? - "#, - ) - .bind(AgentJobItemStatus::Running.as_str()) - .bind(now) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Pending.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn mark_agent_job_item_running_with_thread( - &self, - job_id: &str, - item_id: &str, - thread_id: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET - status = ?, - assigned_thread_id = ?, - attempt_count = attempt_count + 1, - updated_at = ?, - last_error = NULL -WHERE job_id = ? AND item_id = ? AND status = ? 
- "#, - ) - .bind(AgentJobItemStatus::Running.as_str()) - .bind(thread_id) - .bind(now) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Pending.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn mark_agent_job_item_pending( - &self, - job_id: &str, - item_id: &str, - error_message: Option<&str>, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET - status = ?, - assigned_thread_id = NULL, - updated_at = ?, - last_error = ? -WHERE job_id = ? AND item_id = ? AND status = ? - "#, - ) - .bind(AgentJobItemStatus::Pending.as_str()) - .bind(now) - .bind(error_message) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Running.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn set_agent_job_item_thread( - &self, - job_id: &str, - item_id: &str, - thread_id: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET assigned_thread_id = ?, updated_at = ? -WHERE job_id = ? AND item_id = ? AND status = ? - "#, - ) - .bind(thread_id) - .bind(now) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Running.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn report_agent_job_item_result( - &self, - job_id: &str, - item_id: &str, - reporting_thread_id: &str, - result_json: &Value, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let serialized = serde_json::to_string(result_json)?; - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET - result_json = ?, - reported_at = ?, - updated_at = ?, - last_error = NULL -WHERE - job_id = ? - AND item_id = ? - AND status = ? - AND assigned_thread_id = ? 
- "#, - ) - .bind(serialized) - .bind(now) - .bind(now) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Running.as_str()) - .bind(reporting_thread_id) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn mark_agent_job_item_completed( - &self, - job_id: &str, - item_id: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET - status = ?, - completed_at = ?, - updated_at = ?, - assigned_thread_id = NULL -WHERE - job_id = ? - AND item_id = ? - AND status = ? - AND result_json IS NOT NULL - "#, - ) - .bind(AgentJobItemStatus::Completed.as_str()) - .bind(now) - .bind(now) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Running.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn mark_agent_job_item_failed( - &self, - job_id: &str, - item_id: &str, - error_message: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let result = sqlx::query( - r#" -UPDATE agent_job_items -SET - status = ?, - completed_at = ?, - updated_at = ?, - last_error = ?, - assigned_thread_id = NULL -WHERE - job_id = ? - AND item_id = ? - AND status = ? - "#, - ) - .bind(AgentJobItemStatus::Failed.as_str()) - .bind(now) - .bind(now) - .bind(error_message) - .bind(job_id) - .bind(item_id) - .bind(AgentJobItemStatus::Running.as_str()) - .execute(self.pool.as_ref()) - .await?; - Ok(result.rows_affected() > 0) - } - - pub async fn get_agent_job_progress(&self, job_id: &str) -> anyhow::Result { - let row = sqlx::query( - r#" -SELECT - COUNT(*) AS total_items, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS pending_items, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS running_items, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS completed_items, - SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS failed_items -FROM agent_job_items -WHERE job_id = ? 
- "#, - ) - .bind(AgentJobItemStatus::Pending.as_str()) - .bind(AgentJobItemStatus::Running.as_str()) - .bind(AgentJobItemStatus::Completed.as_str()) - .bind(AgentJobItemStatus::Failed.as_str()) - .bind(job_id) - .fetch_one(self.pool.as_ref()) - .await?; - - let total_items: i64 = row.try_get("total_items")?; - let pending_items: Option = row.try_get("pending_items")?; - let running_items: Option = row.try_get("running_items")?; - let completed_items: Option = row.try_get("completed_items")?; - let failed_items: Option = row.try_get("failed_items")?; - Ok(AgentJobProgress { - total_items: usize::try_from(total_items).unwrap_or_default(), - pending_items: usize::try_from(pending_items.unwrap_or_default()).unwrap_or_default(), - running_items: usize::try_from(running_items.unwrap_or_default()).unwrap_or_default(), - completed_items: usize::try_from(completed_items.unwrap_or_default()) - .unwrap_or_default(), - failed_items: usize::try_from(failed_items.unwrap_or_default()).unwrap_or_default(), - }) - } - - async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> { - sqlx::query( - r#" -INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at) -VALUES (?, ?, NULL, NULL, ?) 
-ON CONFLICT(id) DO NOTHING - "#, - ) - .bind(1_i64) - .bind(crate::BackfillStatus::Pending.as_str()) - .bind(Utc::now().timestamp()) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } -} - -fn push_log_filters<'a>(builder: &mut QueryBuilder<'a, Sqlite>, query: &'a LogQuery) { - if let Some(level_upper) = query.level_upper.as_ref() { - builder - .push(" AND UPPER(level) = ") - .push_bind(level_upper.as_str()); - } - if let Some(from_ts) = query.from_ts { - builder.push(" AND ts >= ").push_bind(from_ts); - } - if let Some(to_ts) = query.to_ts { - builder.push(" AND ts <= ").push_bind(to_ts); - } - push_like_filters(builder, "module_path", &query.module_like); - push_like_filters(builder, "file", &query.file_like); - let has_thread_filter = !query.thread_ids.is_empty() || query.include_threadless; - if has_thread_filter { - builder.push(" AND ("); - let mut needs_or = false; - for thread_id in &query.thread_ids { - if needs_or { - builder.push(" OR "); - } - builder.push("thread_id = ").push_bind(thread_id.as_str()); - needs_or = true; - } - if query.include_threadless { - if needs_or { - builder.push(" OR "); - } - builder.push("thread_id IS NULL"); - } - builder.push(")"); - } - if let Some(after_id) = query.after_id { - builder.push(" AND id > ").push_bind(after_id); - } - if let Some(search) = query.search.as_ref() { - builder.push(" AND INSTR(message, "); - builder.push_bind(search.as_str()); - builder.push(") > 0"); - } -} - -fn push_like_filters<'a>( - builder: &mut QueryBuilder<'a, Sqlite>, - column: &str, - filters: &'a [String], -) { - if filters.is_empty() { - return; - } - builder.push(" AND ("); - for (idx, filter) in filters.iter().enumerate() { - if idx > 0 { - builder.push(" OR "); - } - builder - .push(column) - .push(" LIKE '%' || ") - .push_bind(filter.as_str()) - .push(" || '%'"); - } - builder.push(")"); -} - -fn extract_dynamic_tools(items: &[RolloutItem]) -> Option>> { - items.iter().find_map(|item| match item { - 
RolloutItem::SessionMeta(meta_line) => Some(meta_line.meta.dynamic_tools.clone()), - RolloutItem::ResponseItem(_) - | RolloutItem::Compacted(_) - | RolloutItem::TurnContext(_) - | RolloutItem::EventMsg(_) => None, - }) } async fn open_sqlite(path: &Path) -> anyhow::Result { @@ -1650,2902 +198,3 @@ fn should_remove_state_file(file_name: &str, current_name: &str) -> bool { }; !version_suffix.is_empty() && version_suffix.chars().all(|ch| ch.is_ascii_digit()) } - -fn push_thread_filters<'a>( - builder: &mut QueryBuilder<'a, Sqlite>, - archived_only: bool, - allowed_sources: &'a [String], - model_providers: Option<&'a [String]>, - anchor: Option<&crate::Anchor>, - sort_key: SortKey, - search_term: Option<&'a str>, -) { - builder.push(" WHERE 1 = 1"); - if archived_only { - builder.push(" AND archived = 1"); - } else { - builder.push(" AND archived = 0"); - } - builder.push(" AND first_user_message <> ''"); - if !allowed_sources.is_empty() { - builder.push(" AND source IN ("); - let mut separated = builder.separated(", "); - for source in allowed_sources { - separated.push_bind(source); - } - separated.push_unseparated(")"); - } - if let Some(model_providers) = model_providers - && !model_providers.is_empty() - { - builder.push(" AND model_provider IN ("); - let mut separated = builder.separated(", "); - for provider in model_providers { - separated.push_bind(provider); - } - separated.push_unseparated(")"); - } - if let Some(search_term) = search_term { - builder.push(" AND instr(title, "); - builder.push_bind(search_term); - builder.push(") > 0"); - } - if let Some(anchor) = anchor { - let anchor_ts = datetime_to_epoch_seconds(anchor.ts); - let column = match sort_key { - SortKey::CreatedAt => "created_at", - SortKey::UpdatedAt => "updated_at", - }; - builder.push(" AND ("); - builder.push(column); - builder.push(" < "); - builder.push_bind(anchor_ts); - builder.push(" OR ("); - builder.push(column); - builder.push(" = "); - builder.push_bind(anchor_ts); - 
builder.push(" AND id < "); - builder.push_bind(anchor.id.to_string()); - builder.push("))"); - } -} - -fn push_thread_order_and_limit( - builder: &mut QueryBuilder<'_, Sqlite>, - sort_key: SortKey, - limit: usize, -) { - let order_column = match sort_key { - SortKey::CreatedAt => "created_at", - SortKey::UpdatedAt => "updated_at", - }; - builder.push(" ORDER BY "); - builder.push(order_column); - builder.push(" DESC, id DESC"); - builder.push(" LIMIT "); - builder.push_bind(limit as i64); -} - -#[cfg(test)] -mod tests { - use super::StateRuntime; - use super::ThreadMetadata; - use super::state_db_filename; - use crate::LogEntry; - use crate::LogQuery; - use crate::STATE_DB_FILENAME; - use crate::STATE_DB_VERSION; - use crate::model::Phase2JobClaimOutcome; - use crate::model::Stage1JobClaimOutcome; - use crate::model::Stage1StartupClaimParams; - use chrono::DateTime; - use chrono::Duration; - use chrono::Utc; - use codex_protocol::ThreadId; - use codex_protocol::protocol::AskForApproval; - use codex_protocol::protocol::SandboxPolicy; - use pretty_assertions::assert_eq; - use sqlx::Row; - use std::path::Path; - use std::path::PathBuf; - use std::sync::Arc; - use std::time::SystemTime; - use std::time::UNIX_EPOCH; - use uuid::Uuid; - - fn unique_temp_dir() -> PathBuf { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| duration.as_nanos()); - std::env::temp_dir().join(format!( - "codex-state-runtime-test-{nanos}-{}", - Uuid::new_v4() - )) - } - - #[tokio::test] - async fn init_removes_legacy_state_db_files() { - let codex_home = unique_temp_dir(); - tokio::fs::create_dir_all(&codex_home) - .await - .expect("create codex_home"); - - let current_name = state_db_filename(); - let previous_version = STATE_DB_VERSION.saturating_sub(1); - let unversioned_name = format!("{STATE_DB_FILENAME}.sqlite"); - for suffix in ["", "-wal", "-shm", "-journal"] { - let path = codex_home.join(format!("{unversioned_name}{suffix}")); - 
tokio::fs::write(path, b"legacy") - .await - .expect("write legacy"); - let old_version_path = codex_home.join(format!( - "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}" - )); - tokio::fs::write(old_version_path, b"old_version") - .await - .expect("write old version"); - } - let unrelated_path = codex_home.join("state.sqlite_backup"); - tokio::fs::write(&unrelated_path, b"keep") - .await - .expect("write unrelated"); - let numeric_path = codex_home.join("123"); - tokio::fs::write(&numeric_path, b"keep") - .await - .expect("write numeric"); - - let _runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - for suffix in ["", "-wal", "-shm", "-journal"] { - let legacy_path = codex_home.join(format!("{unversioned_name}{suffix}")); - assert_eq!( - tokio::fs::try_exists(&legacy_path) - .await - .expect("check legacy path"), - false - ); - let old_version_path = codex_home.join(format!( - "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}" - )); - assert_eq!( - tokio::fs::try_exists(&old_version_path) - .await - .expect("check old version path"), - false - ); - } - assert_eq!( - tokio::fs::try_exists(codex_home.join(current_name)) - .await - .expect("check new db path"), - true - ); - assert_eq!( - tokio::fs::try_exists(&unrelated_path) - .await - .expect("check unrelated path"), - true - ); - assert_eq!( - tokio::fs::try_exists(&numeric_path) - .await - .expect("check numeric path"), - true - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn backfill_state_persists_progress_and_completion() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let initial = runtime - .get_backfill_state() - .await - .expect("get initial backfill state"); - assert_eq!(initial.status, crate::BackfillStatus::Pending); - 
assert_eq!(initial.last_watermark, None); - assert_eq!(initial.last_success_at, None); - - runtime - .mark_backfill_running() - .await - .expect("mark backfill running"); - runtime - .checkpoint_backfill("sessions/2026/01/27/rollout-a.jsonl") - .await - .expect("checkpoint backfill"); - - let running = runtime - .get_backfill_state() - .await - .expect("get running backfill state"); - assert_eq!(running.status, crate::BackfillStatus::Running); - assert_eq!( - running.last_watermark, - Some("sessions/2026/01/27/rollout-a.jsonl".to_string()) - ); - assert_eq!(running.last_success_at, None); - - runtime - .mark_backfill_complete(Some("sessions/2026/01/28/rollout-b.jsonl")) - .await - .expect("mark backfill complete"); - let completed = runtime - .get_backfill_state() - .await - .expect("get completed backfill state"); - assert_eq!(completed.status, crate::BackfillStatus::Complete); - assert_eq!( - completed.last_watermark, - Some("sessions/2026/01/28/rollout-b.jsonl".to_string()) - ); - assert!(completed.last_success_at.is_some()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn backfill_claim_is_singleton_until_stale_and_blocked_when_complete() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let claimed = runtime - .try_claim_backfill(3600) - .await - .expect("initial backfill claim"); - assert_eq!(claimed, true); - - let duplicate_claim = runtime - .try_claim_backfill(3600) - .await - .expect("duplicate backfill claim"); - assert_eq!(duplicate_claim, false); - - let stale_updated_at = Utc::now().timestamp().saturating_sub(10_000); - sqlx::query( - r#" -UPDATE backfill_state -SET status = ?, updated_at = ? 
-WHERE id = 1 - "#, - ) - .bind(crate::BackfillStatus::Running.as_str()) - .bind(stale_updated_at) - .execute(runtime.pool.as_ref()) - .await - .expect("force stale backfill lease"); - - let stale_claim = runtime - .try_claim_backfill(10) - .await - .expect("stale backfill claim"); - assert_eq!(stale_claim, true); - - runtime - .mark_backfill_complete(None) - .await - .expect("mark complete"); - let claim_after_complete = runtime - .try_claim_backfill(3600) - .await - .expect("claim after complete"); - assert_eq!(claim_after_complete, false); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn stage1_claim_skips_when_up_to_date() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.join("a")); - runtime - .upsert_thread(&metadata) - .await - .expect("upsert thread"); - - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - let claim = runtime - .try_claim_stage1_job(thread_id, owner_a, 100, 3600, 64) - .await - .expect("claim stage1 job"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - ownership_token.as_str(), - 100, - "raw", - "sum", - None, - ) - .await - .expect("mark stage1 succeeded"), - "stage1 success should finalize for current token" - ); - - let up_to_date = runtime - .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) - .await - .expect("claim stage1 up-to-date"); - assert_eq!(up_to_date, 
Stage1JobClaimOutcome::SkippedUpToDate); - - let needs_rerun = runtime - .try_claim_stage1_job(thread_id, owner_b, 101, 3600, 64) - .await - .expect("claim stage1 newer source"); - assert!( - matches!(needs_rerun, Stage1JobClaimOutcome::Claimed { .. }), - "newer source_updated_at should be claimable" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn stage1_running_stale_can_be_stolen_but_fresh_running_is_skipped() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - - let claim_a = runtime - .try_claim_stage1_job(thread_id, owner_a, 100, 3600, 64) - .await - .expect("claim a"); - assert!(matches!(claim_a, Stage1JobClaimOutcome::Claimed { .. })); - - let claim_b_fresh = runtime - .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) - .await - .expect("claim b fresh"); - assert_eq!(claim_b_fresh, Stage1JobClaimOutcome::SkippedRunning); - - sqlx::query("UPDATE jobs SET lease_until = 0 WHERE kind = 'memory_stage1' AND job_key = ?") - .bind(thread_id.to_string()) - .execute(runtime.pool.as_ref()) - .await - .expect("force stale lease"); - - let claim_b_stale = runtime - .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) - .await - .expect("claim b stale"); - assert!(matches!( - claim_b_stale, - Stage1JobClaimOutcome::Claimed { .. 
} - )); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn stage1_concurrent_claim_for_same_thread_is_conflict_safe() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let thread_id_a = thread_id; - let thread_id_b = thread_id; - let runtime_a = Arc::clone(&runtime); - let runtime_b = Arc::clone(&runtime); - let claim_with_retry = |runtime: Arc, - thread_id: ThreadId, - owner: ThreadId| async move { - for attempt in 0..5 { - match runtime - .try_claim_stage1_job(thread_id, owner, 100, 3_600, 64) - .await - { - Ok(outcome) => return outcome, - Err(err) if err.to_string().contains("database is locked") && attempt < 4 => { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - Err(err) => panic!("claim stage1 should not fail: {err}"), - } - } - panic!("claim stage1 should have returned within retry budget") - }; - - let (claim_a, claim_b) = tokio::join!( - claim_with_retry(runtime_a, thread_id_a, owner_a), - claim_with_retry(runtime_b, thread_id_b, owner_b), - ); - - let claim_outcomes = vec![claim_a, claim_b]; - let claimed_count = claim_outcomes - .iter() - .filter(|outcome| matches!(outcome, Stage1JobClaimOutcome::Claimed { .. })) - .count(); - assert_eq!(claimed_count, 1); - assert!( - claim_outcomes.iter().all(|outcome| { - matches!( - outcome, - Stage1JobClaimOutcome::Claimed { .. 
} | Stage1JobClaimOutcome::SkippedRunning - ) - }), - "unexpected claim outcomes: {claim_outcomes:?}" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn stage1_concurrent_claims_respect_running_cap() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_a, - codex_home.join("workspace-a"), - )) - .await - .expect("upsert thread a"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_b, - codex_home.join("workspace-b"), - )) - .await - .expect("upsert thread b"); - - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let runtime_a = Arc::clone(&runtime); - let runtime_b = Arc::clone(&runtime); - - let (claim_a, claim_b) = tokio::join!( - async move { - runtime_a - .try_claim_stage1_job(thread_a, owner_a, 100, 3_600, 1) - .await - .expect("claim stage1 thread a") - }, - async move { - runtime_b - .try_claim_stage1_job(thread_b, owner_b, 101, 3_600, 1) - .await - .expect("claim stage1 thread b") - }, - ); - - let claim_outcomes = vec![claim_a, claim_b]; - let claimed_count = claim_outcomes - .iter() - .filter(|outcome| matches!(outcome, Stage1JobClaimOutcome::Claimed { .. 
})) - .count(); - assert_eq!(claimed_count, 1); - assert!( - claim_outcomes - .iter() - .any(|outcome| { matches!(outcome, Stage1JobClaimOutcome::SkippedRunning) }), - "one concurrent claim should be throttled by running cap: {claim_outcomes:?}" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn claim_stage1_jobs_filters_by_age_idle_and_current_thread() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let now = Utc::now(); - let fresh_at = now - Duration::hours(1); - let just_under_idle_at = now - Duration::hours(12) + Duration::minutes(1); - let eligible_idle_at = now - Duration::hours(12) - Duration::minutes(1); - let old_at = now - Duration::days(31); - - let current_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); - let fresh_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("fresh thread id"); - let just_under_idle_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("just under idle thread id"); - let eligible_idle_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("eligible idle thread id"); - let old_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("old thread id"); - - let mut current = - test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); - current.created_at = now; - current.updated_at = now; - runtime - .upsert_thread(¤t) - .await - .expect("upsert current"); - - let mut fresh = - test_thread_metadata(&codex_home, fresh_thread_id, codex_home.join("fresh")); - fresh.created_at = fresh_at; - fresh.updated_at = fresh_at; - runtime.upsert_thread(&fresh).await.expect("upsert fresh"); - - let mut just_under_idle = test_thread_metadata( - &codex_home, - just_under_idle_thread_id, - codex_home.join("just-under-idle"), - ); - 
just_under_idle.created_at = just_under_idle_at; - just_under_idle.updated_at = just_under_idle_at; - runtime - .upsert_thread(&just_under_idle) - .await - .expect("upsert just-under-idle"); - - let mut eligible_idle = test_thread_metadata( - &codex_home, - eligible_idle_thread_id, - codex_home.join("eligible-idle"), - ); - eligible_idle.created_at = eligible_idle_at; - eligible_idle.updated_at = eligible_idle_at; - runtime - .upsert_thread(&eligible_idle) - .await - .expect("upsert eligible-idle"); - - let mut old = test_thread_metadata(&codex_home, old_thread_id, codex_home.join("old")); - old.created_at = old_at; - old.updated_at = old_at; - runtime.upsert_thread(&old).await.expect("upsert old"); - - let allowed_sources = vec!["cli".to_string()]; - let claims = runtime - .claim_stage1_jobs_for_startup( - current_thread_id, - Stage1StartupClaimParams { - scan_limit: 1, - max_claimed: 5, - max_age_days: 30, - min_rollout_idle_hours: 12, - allowed_sources: allowed_sources.as_slice(), - lease_seconds: 3600, - }, - ) - .await - .expect("claim stage1 jobs"); - - assert_eq!(claims.len(), 1); - assert_eq!(claims[0].thread.id, eligible_idle_thread_id); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn claim_stage1_jobs_prefilters_threads_with_up_to_date_memory() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let now = Utc::now(); - let eligible_newer_at = now - Duration::hours(13); - let eligible_older_at = now - Duration::hours(14); - - let current_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); - let up_to_date_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("up-to-date thread id"); - let stale_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("stale thread id"); - let worker_id = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("worker id"); - - let mut current = - test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); - current.created_at = now; - current.updated_at = now; - runtime - .upsert_thread(¤t) - .await - .expect("upsert current thread"); - - let mut up_to_date = test_thread_metadata( - &codex_home, - up_to_date_thread_id, - codex_home.join("up-to-date"), - ); - up_to_date.created_at = eligible_newer_at; - up_to_date.updated_at = eligible_newer_at; - runtime - .upsert_thread(&up_to_date) - .await - .expect("upsert up-to-date thread"); - - let up_to_date_claim = runtime - .try_claim_stage1_job( - up_to_date_thread_id, - worker_id, - up_to_date.updated_at.timestamp(), - 3600, - 64, - ) - .await - .expect("claim up-to-date thread for seed"); - let up_to_date_token = match up_to_date_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected seed claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - up_to_date_thread_id, - up_to_date_token.as_str(), - up_to_date.updated_at.timestamp(), - "raw", - "summary", - None, - ) - .await - .expect("mark up-to-date thread succeeded"), - "seed stage1 success should complete for up-to-date thread" - ); - - let mut stale = - test_thread_metadata(&codex_home, stale_thread_id, codex_home.join("stale")); - stale.created_at = eligible_older_at; - stale.updated_at = eligible_older_at; - runtime - .upsert_thread(&stale) - .await - .expect("upsert stale thread"); - - let allowed_sources = vec!["cli".to_string()]; - let claims = runtime - .claim_stage1_jobs_for_startup( - current_thread_id, - Stage1StartupClaimParams { - scan_limit: 1, - max_claimed: 1, - max_age_days: 30, - min_rollout_idle_hours: 12, - allowed_sources: allowed_sources.as_slice(), - lease_seconds: 3600, - }, - ) - .await - .expect("claim stage1 startup jobs"); - assert_eq!(claims.len(), 1); - 
assert_eq!(claims[0].thread.id, stale_thread_id); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn claim_stage1_jobs_enforces_global_running_cap() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let current_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - current_thread_id, - codex_home.join("current"), - )) - .await - .expect("upsert current"); - - let now = Utc::now(); - let started_at = now.timestamp(); - let lease_until = started_at + 3600; - let eligible_at = now - Duration::hours(13); - let existing_running = 10usize; - let total_candidates = 80usize; - - for idx in 0..total_candidates { - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let mut metadata = test_thread_metadata( - &codex_home, - thread_id, - codex_home.join(format!("thread-{idx}")), - ); - metadata.created_at = eligible_at - Duration::seconds(idx as i64); - metadata.updated_at = eligible_at - Duration::seconds(idx as i64); - runtime - .upsert_thread(&metadata) - .await - .expect("upsert thread"); - - if idx < existing_running { - sqlx::query( - r#" -INSERT INTO jobs ( - kind, - job_key, - status, - worker_id, - ownership_token, - started_at, - finished_at, - lease_until, - retry_at, - retry_remaining, - last_error, - input_watermark, - last_success_watermark -) VALUES (?, ?, 'running', ?, ?, ?, NULL, ?, NULL, ?, NULL, ?, NULL) - "#, - ) - .bind("memory_stage1") - .bind(thread_id.to_string()) - .bind(current_thread_id.to_string()) - .bind(Uuid::new_v4().to_string()) - .bind(started_at) - .bind(lease_until) - .bind(3) - .bind(metadata.updated_at.timestamp()) - .execute(runtime.pool.as_ref()) - .await - .expect("seed running stage1 job"); - } - } - - let allowed_sources = 
vec!["cli".to_string()]; - let claims = runtime - .claim_stage1_jobs_for_startup( - current_thread_id, - Stage1StartupClaimParams { - scan_limit: 200, - max_claimed: 64, - max_age_days: 30, - min_rollout_idle_hours: 12, - allowed_sources: allowed_sources.as_slice(), - lease_seconds: 3600, - }, - ) - .await - .expect("claim stage1 jobs"); - assert_eq!(claims.len(), 54); - - let running_count = sqlx::query( - r#" -SELECT COUNT(*) AS count -FROM jobs -WHERE kind = 'memory_stage1' - AND status = 'running' - AND lease_until IS NOT NULL - AND lease_until > ? - "#, - ) - .bind(Utc::now().timestamp()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("count running stage1 jobs") - .try_get::("count") - .expect("running count value"); - assert_eq!(running_count, 64); - - let more_claims = runtime - .claim_stage1_jobs_for_startup( - current_thread_id, - Stage1StartupClaimParams { - scan_limit: 200, - max_claimed: 64, - max_age_days: 30, - min_rollout_idle_hours: 12, - allowed_sources: allowed_sources.as_slice(), - lease_seconds: 3600, - }, - ) - .await - .expect("claim stage1 jobs with cap reached"); - assert_eq!(more_claims.len(), 0); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn claim_stage1_jobs_processes_two_full_batches_across_startup_passes() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let current_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); - let mut current = - test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); - current.created_at = Utc::now(); - current.updated_at = Utc::now(); - runtime - .upsert_thread(¤t) - .await - .expect("upsert current"); - - let eligible_at = Utc::now() - Duration::hours(13); - for idx in 0..200 { - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread 
id"); - let mut metadata = test_thread_metadata( - &codex_home, - thread_id, - codex_home.join(format!("thread-{idx}")), - ); - metadata.created_at = eligible_at - Duration::seconds(idx as i64); - metadata.updated_at = eligible_at - Duration::seconds(idx as i64); - runtime - .upsert_thread(&metadata) - .await - .expect("upsert eligible thread"); - } - - let allowed_sources = vec!["cli".to_string()]; - let first_claims = runtime - .claim_stage1_jobs_for_startup( - current_thread_id, - Stage1StartupClaimParams { - scan_limit: 5_000, - max_claimed: 64, - max_age_days: 30, - min_rollout_idle_hours: 12, - allowed_sources: allowed_sources.as_slice(), - lease_seconds: 3_600, - }, - ) - .await - .expect("first stage1 startup claim"); - assert_eq!(first_claims.len(), 64); - - for claim in first_claims { - assert!( - runtime - .mark_stage1_job_succeeded( - claim.thread.id, - claim.ownership_token.as_str(), - claim.thread.updated_at.timestamp(), - "raw", - "summary", - None, - ) - .await - .expect("mark first-batch stage1 success"), - "first batch stage1 completion should succeed" - ); - } - - let second_claims = runtime - .claim_stage1_jobs_for_startup( - current_thread_id, - Stage1StartupClaimParams { - scan_limit: 5_000, - max_claimed: 64, - max_age_days: 30, - min_rollout_idle_hours: 12, - allowed_sources: allowed_sources.as_slice(), - lease_seconds: 3_600, - }, - ) - .await - .expect("second stage1 startup claim"); - assert_eq!(second_claims.len(), 64); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn stage1_output_cascades_on_thread_delete() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let cwd = codex_home.join("workspace"); - 
runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - - let claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) - .await - .expect("claim stage1"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - ownership_token.as_str(), - 100, - "raw", - "sum", - None, - ) - .await - .expect("mark stage1 succeeded"), - "mark stage1 succeeded should write stage1_outputs" - ); - - let count_before = - sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("count before delete") - .try_get::("count") - .expect("count value"); - assert_eq!(count_before, 1); - - sqlx::query("DELETE FROM threads WHERE id = ?") - .bind(thread_id.to_string()) - .execute(runtime.pool.as_ref()) - .await - .expect("delete thread"); - - let count_after = - sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("count after delete") - .try_get::("count") - .expect("count value"); - assert_eq!(count_after, 0); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn mark_stage1_job_succeeded_no_output_skips_phase2_when_output_was_already_absent() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - 
.upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - let claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) - .await - .expect("claim stage1"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded_no_output(thread_id, ownership_token.as_str()) - .await - .expect("mark stage1 succeeded without output"), - "stage1 no-output success should complete the job" - ); - - let output_row_count = - sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load stage1 output count") - .try_get::("count") - .expect("stage1 output count"); - assert_eq!( - output_row_count, 0, - "stage1 no-output success should not persist empty stage1 outputs" - ); - - let up_to_date = runtime - .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) - .await - .expect("claim stage1 up-to-date"); - assert_eq!(up_to_date, Stage1JobClaimOutcome::SkippedUpToDate); - - let global_job_row_count = sqlx::query("SELECT COUNT(*) AS count FROM jobs WHERE kind = ?") - .bind("memory_consolidate_global") - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load phase2 job row count") - .try_get::("count") - .expect("phase2 job row count"); - assert_eq!( - global_job_row_count, 0, - "no-output without an existing stage1 output should not enqueue phase2" - ); - - let claim_phase2 = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2"); - assert_eq!( - claim_phase2, - Phase2JobClaimOutcome::SkippedNotDirty, - "phase2 should remain clean when no-output deleted nothing" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn 
mark_stage1_job_succeeded_no_output_enqueues_phase2_when_deleting_output() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - let first_claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) - .await - .expect("claim initial stage1"); - let first_token = match first_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected initial stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded(thread_id, first_token.as_str(), 100, "raw", "sum", None) - .await - .expect("mark initial stage1 succeeded"), - "initial stage1 success should create stage1 output" - ); - - let phase2_claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2 after initial output"); - let (phase2_token, phase2_input_watermark) = match phase2_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim after initial output: {other:?}"), - }; - assert_eq!(phase2_input_watermark, 100); - assert!( - runtime - .mark_global_phase2_job_succeeded( - phase2_token.as_str(), - phase2_input_watermark, - &[], - ) - .await - .expect("mark initial phase2 succeeded"), - "initial phase2 success should clear global dirty state" - ); - - let no_output_claim = runtime - .try_claim_stage1_job(thread_id, owner_b, 101, 3600, 64) - .await - 
.expect("claim stage1 for no-output delete"); - let no_output_token = match no_output_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected no-output stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded_no_output(thread_id, no_output_token.as_str()) - .await - .expect("mark stage1 no-output after existing output"), - "no-output should succeed when deleting an existing stage1 output" - ); - - let output_row_count = - sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load stage1 output count after delete") - .try_get::("count") - .expect("stage1 output count"); - assert_eq!(output_row_count, 0); - - let claim_phase2 = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2 after no-output deletion"); - let (phase2_token, phase2_input_watermark) = match claim_phase2 { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim after no-output deletion: {other:?}"), - }; - assert_eq!(phase2_input_watermark, 101); - assert!( - runtime - .mark_global_phase2_job_succeeded( - phase2_token.as_str(), - phase2_input_watermark, - &[], - ) - .await - .expect("mark phase2 succeeded after no-output delete") - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn stage1_retry_exhaustion_does_not_block_newer_watermark() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - 
.upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - for attempt in 0..3 { - let claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3_600, 64) - .await - .expect("claim stage1 for retry exhaustion"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!( - "attempt {} should claim stage1 before retries are exhausted: {other:?}", - attempt + 1 - ), - }; - assert!( - runtime - .mark_stage1_job_failed(thread_id, ownership_token.as_str(), "boom", 0) - .await - .expect("mark stage1 failed"), - "attempt {} should decrement retry budget", - attempt + 1 - ); - } - - let exhausted_claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3_600, 64) - .await - .expect("claim stage1 after retry exhaustion"); - assert_eq!( - exhausted_claim, - Stage1JobClaimOutcome::SkippedRetryExhausted - ); - - let newer_source_claim = runtime - .try_claim_stage1_job(thread_id, owner, 101, 3_600, 64) - .await - .expect("claim stage1 with newer source watermark"); - assert!( - matches!(newer_source_claim, Stage1JobClaimOutcome::Claimed { .. }), - "newer source watermark should reset retry budget and be claimable" - ); - - let job_row = sqlx::query( - "SELECT retry_remaining, input_watermark FROM jobs WHERE kind = ? 
AND job_key = ?", - ) - .bind("memory_stage1") - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load stage1 job row after newer-source claim"); - assert_eq!( - job_row - .try_get::("retry_remaining") - .expect("retry_remaining"), - 3 - ); - assert_eq!( - job_row - .try_get::("input_watermark") - .expect("input_watermark"), - 101 - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_global_consolidation_reruns_when_watermark_advances() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - runtime - .enqueue_global_consolidation(100) - .await - .expect("enqueue global consolidation"); - - let claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2"); - let (ownership_token, input_watermark) = match claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark, &[],) - .await - .expect("mark phase2 succeeded"), - "phase2 success should finalize for current token" - ); - - let claim_up_to_date = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2 up-to-date"); - assert_eq!(claim_up_to_date, Phase2JobClaimOutcome::SkippedNotDirty); - - runtime - .enqueue_global_consolidation(101) - .await - .expect("enqueue global consolidation again"); - - let claim_rerun = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2 rerun"); - assert!( - matches!(claim_rerun, Phase2JobClaimOutcome::Claimed { .. 
}), - "advanced watermark should be claimable" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn list_stage1_outputs_for_global_returns_latest_outputs() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id_a, - codex_home.join("workspace-a"), - )) - .await - .expect("upsert thread a"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id_b, - codex_home.join("workspace-b"), - )) - .await - .expect("upsert thread b"); - - let claim = runtime - .try_claim_stage1_job(thread_id_a, owner, 100, 3600, 64) - .await - .expect("claim stage1 a"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id_a, - ownership_token.as_str(), - 100, - "raw memory a", - "summary a", - None, - ) - .await - .expect("mark stage1 succeeded a"), - "stage1 success should persist output a" - ); - - let claim = runtime - .try_claim_stage1_job(thread_id_b, owner, 101, 3600, 64) - .await - .expect("claim stage1 b"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id_b, - ownership_token.as_str(), - 101, - "raw memory b", - "summary b", - Some("rollout-b"), - ) - .await - .expect("mark stage1 
succeeded b"), - "stage1 success should persist output b" - ); - - let outputs = runtime - .list_stage1_outputs_for_global(10) - .await - .expect("list stage1 outputs for global"); - assert_eq!(outputs.len(), 2); - assert_eq!(outputs[0].thread_id, thread_id_b); - assert_eq!(outputs[0].rollout_summary, "summary b"); - assert_eq!(outputs[0].rollout_slug.as_deref(), Some("rollout-b")); - assert_eq!(outputs[0].cwd, codex_home.join("workspace-b")); - assert_eq!(outputs[1].thread_id, thread_id_a); - assert_eq!(outputs[1].rollout_summary, "summary a"); - assert_eq!(outputs[1].rollout_slug, None); - assert_eq!(outputs[1].cwd, codex_home.join("workspace-a")); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn list_stage1_outputs_for_global_skips_empty_payloads() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id_non_empty = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let thread_id_empty = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id_non_empty, - codex_home.join("workspace-non-empty"), - )) - .await - .expect("upsert non-empty thread"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id_empty, - codex_home.join("workspace-empty"), - )) - .await - .expect("upsert empty thread"); - - sqlx::query( - r#" -INSERT INTO stage1_outputs (thread_id, source_updated_at, raw_memory, rollout_summary, generated_at) -VALUES (?, ?, ?, ?, ?) 
- "#, - ) - .bind(thread_id_non_empty.to_string()) - .bind(100_i64) - .bind("raw memory") - .bind("summary") - .bind(100_i64) - .execute(runtime.pool.as_ref()) - .await - .expect("insert non-empty stage1 output"); - sqlx::query( - r#" -INSERT INTO stage1_outputs (thread_id, source_updated_at, raw_memory, rollout_summary, generated_at) -VALUES (?, ?, ?, ?, ?) - "#, - ) - .bind(thread_id_empty.to_string()) - .bind(101_i64) - .bind("") - .bind("") - .bind(101_i64) - .execute(runtime.pool.as_ref()) - .await - .expect("insert empty stage1 output"); - - let outputs = runtime - .list_stage1_outputs_for_global(1) - .await - .expect("list stage1 outputs for global"); - assert_eq!(outputs.len(), 1); - assert_eq!(outputs[0].thread_id, thread_id_non_empty); - assert_eq!(outputs[0].rollout_summary, "summary"); - assert_eq!(outputs[0].cwd, codex_home.join("workspace-non-empty")); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - for (thread_id, workspace) in [ - (thread_id_a, "workspace-a"), - (thread_id_b, "workspace-b"), - (thread_id_c, "workspace-c"), - ] { - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join(workspace), - )) - .await - .expect("upsert thread"); - } - - for (thread_id, updated_at, slug) in [ - (thread_id_a, 100, Some("rollout-a")), - (thread_id_b, 101, Some("rollout-b")), - 
(thread_id_c, 102, Some("rollout-c")), - ] { - let claim = runtime - .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) - .await - .expect("claim stage1"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - ownership_token.as_str(), - updated_at, - &format!("raw-{updated_at}"), - &format!("summary-{updated_at}"), - slug, - ) - .await - .expect("mark stage1 succeeded"), - "stage1 success should persist output" - ); - } - - let claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2"); - let (ownership_token, input_watermark) = match claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim outcome: {other:?}"), - }; - assert_eq!(input_watermark, 102); - let selected_outputs = runtime - .list_stage1_outputs_for_global(10) - .await - .expect("list stage1 outputs for global") - .into_iter() - .filter(|output| output.thread_id == thread_id_c || output.thread_id == thread_id_a) - .collect::>(); - assert!( - runtime - .mark_global_phase2_job_succeeded( - ownership_token.as_str(), - input_watermark, - &selected_outputs, - ) - .await - .expect("mark phase2 success with selection"), - "phase2 success should persist selected rows" - ); - - let selection = runtime - .get_phase2_input_selection(2) - .await - .expect("load phase2 input selection"); - - assert_eq!(selection.selected.len(), 2); - assert_eq!(selection.previous_selected.len(), 2); - assert_eq!(selection.selected[0].thread_id, thread_id_c); - assert_eq!( - selection.selected[0].rollout_path, - codex_home.join(format!("rollout-{thread_id_c}.jsonl")) - ); - assert_eq!(selection.selected[1].thread_id, thread_id_b); - assert_eq!(selection.retained_thread_ids, 
vec![thread_id_c]); - - assert_eq!(selection.removed.len(), 1); - assert_eq!(selection.removed[0].thread_id, thread_id_a); - assert_eq!( - selection.removed[0].rollout_slug.as_deref(), - Some("rollout-a") - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - let first_claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) - .await - .expect("claim initial stage1"); - let first_token = match first_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - first_token.as_str(), - 100, - "raw-100", - "summary-100", - Some("rollout-100"), - ) - .await - .expect("mark initial stage1 success"), - "initial stage1 success should persist output" - ); - - let phase2_claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2"); - let (phase2_token, input_watermark) = match phase2_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim outcome: {other:?}"), - }; - let selected_outputs = runtime - .list_stage1_outputs_for_global(1) - .await - .expect("list selected outputs"); - assert!( - runtime - .mark_global_phase2_job_succeeded( - phase2_token.as_str(), 
- input_watermark, - &selected_outputs, - ) - .await - .expect("mark phase2 success"), - "phase2 success should persist selected rows" - ); - - let refreshed_claim = runtime - .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) - .await - .expect("claim refreshed stage1"); - let refreshed_token = match refreshed_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - refreshed_token.as_str(), - 101, - "raw-101", - "summary-101", - Some("rollout-101"), - ) - .await - .expect("mark refreshed stage1 success"), - "refreshed stage1 success should persist output" - ); - - let selection = runtime - .get_phase2_input_selection(1) - .await - .expect("load phase2 input selection"); - assert_eq!(selection.selected.len(), 1); - assert_eq!(selection.previous_selected.len(), 1); - assert_eq!(selection.selected[0].thread_id, thread_id); - assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101); - assert!(selection.retained_thread_ids.is_empty()); - assert!(selection.removed.is_empty()); - - let (selected_for_phase2, selected_for_phase2_source_updated_at) = - sqlx::query_as::<_, (i64, Option)>( - "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", - ) - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load selected_for_phase2"); - assert_eq!(selected_for_phase2, 1); - assert_eq!(selected_for_phase2_source_updated_at, Some(100)); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn get_phase2_input_selection_reports_regenerated_previous_selection_as_removed() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id_a = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread a"); - let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread b"); - let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread c"); - let thread_id_d = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread d"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - for (thread_id, workspace) in [ - (thread_id_a, "workspace-a"), - (thread_id_b, "workspace-b"), - (thread_id_c, "workspace-c"), - (thread_id_d, "workspace-d"), - ] { - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join(workspace), - )) - .await - .expect("upsert thread"); - } - - for (thread_id, updated_at, slug) in [ - (thread_id_a, 100, Some("rollout-a-100")), - (thread_id_b, 101, Some("rollout-b-101")), - (thread_id_c, 99, Some("rollout-c-99")), - (thread_id_d, 98, Some("rollout-d-98")), - ] { - let claim = runtime - .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) - .await - .expect("claim initial stage1"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - ownership_token.as_str(), - updated_at, - &format!("raw-{updated_at}"), - &format!("summary-{updated_at}"), - slug, - ) - .await - .expect("mark stage1 succeeded"), - "stage1 success should persist output" - ); - } - - let phase2_claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2"); - let (phase2_token, input_watermark) = match phase2_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim outcome: {other:?}"), - }; - let selected_outputs = runtime - .list_stage1_outputs_for_global(2) - 
.await - .expect("list selected outputs"); - assert_eq!( - selected_outputs - .iter() - .map(|output| output.thread_id) - .collect::>(), - vec![thread_id_b, thread_id_a] - ); - assert!( - runtime - .mark_global_phase2_job_succeeded( - phase2_token.as_str(), - input_watermark, - &selected_outputs, - ) - .await - .expect("mark phase2 success"), - "phase2 success should persist selected rows" - ); - - for (thread_id, updated_at, slug) in [ - (thread_id_a, 102, Some("rollout-a-102")), - (thread_id_c, 103, Some("rollout-c-103")), - (thread_id_d, 104, Some("rollout-d-104")), - ] { - let claim = runtime - .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) - .await - .expect("claim refreshed stage1"); - let ownership_token = match claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - ownership_token.as_str(), - updated_at, - &format!("raw-{updated_at}"), - &format!("summary-{updated_at}"), - slug, - ) - .await - .expect("mark refreshed stage1 success"), - "refreshed stage1 success should persist output" - ); - } - - let selection = runtime - .get_phase2_input_selection(2) - .await - .expect("load phase2 input selection"); - assert_eq!( - selection - .selected - .iter() - .map(|output| output.thread_id) - .collect::>(), - vec![thread_id_d, thread_id_c] - ); - assert_eq!( - selection - .previous_selected - .iter() - .map(|output| output.thread_id) - .collect::>(), - vec![thread_id_a, thread_id_b] - ); - assert!(selection.retained_thread_ids.is_empty()); - assert_eq!( - selection - .removed - .iter() - .map(|output| (output.thread_id, output.source_updated_at.timestamp())) - .collect::>(), - vec![(thread_id_a, 102), (thread_id_b, 101)] - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn mark_global_phase2_job_succeeded_updates_selected_snapshot_timestamp() { - 
let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - let initial_claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) - .await - .expect("claim initial stage1"); - let initial_token = match initial_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - initial_token.as_str(), - 100, - "raw-100", - "summary-100", - Some("rollout-100"), - ) - .await - .expect("mark initial stage1 success"), - "initial stage1 success should persist output" - ); - - let first_phase2_claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim first phase2"); - let (first_phase2_token, first_input_watermark) = match first_phase2_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected first phase2 claim outcome: {other:?}"), - }; - let first_selected_outputs = runtime - .list_stage1_outputs_for_global(1) - .await - .expect("list first selected outputs"); - assert!( - runtime - .mark_global_phase2_job_succeeded( - first_phase2_token.as_str(), - first_input_watermark, - &first_selected_outputs, - ) - .await - .expect("mark first phase2 success"), - "first phase2 success should persist selected rows" - ); - - let refreshed_claim = runtime - .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) - .await - .expect("claim refreshed stage1"); - let 
refreshed_token = match refreshed_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected refreshed stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - refreshed_token.as_str(), - 101, - "raw-101", - "summary-101", - Some("rollout-101"), - ) - .await - .expect("mark refreshed stage1 success"), - "refreshed stage1 success should persist output" - ); - - let second_phase2_claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim second phase2"); - let (second_phase2_token, second_input_watermark) = match second_phase2_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected second phase2 claim outcome: {other:?}"), - }; - let second_selected_outputs = runtime - .list_stage1_outputs_for_global(1) - .await - .expect("list second selected outputs"); - assert_eq!( - second_selected_outputs[0].source_updated_at.timestamp(), - 101 - ); - assert!( - runtime - .mark_global_phase2_job_succeeded( - second_phase2_token.as_str(), - second_input_watermark, - &second_selected_outputs, - ) - .await - .expect("mark second phase2 success"), - "second phase2 success should persist selected rows" - ); - - let selection = runtime - .get_phase2_input_selection(1) - .await - .expect("load phase2 input selection after refresh"); - assert_eq!(selection.retained_thread_ids, vec![thread_id]); - - let (selected_for_phase2, selected_for_phase2_source_updated_at) = - sqlx::query_as::<_, (i64, Option)>( - "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", - ) - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load selected snapshot after phase2"); - assert_eq!(selected_for_phase2, 1); - assert_eq!(selected_for_phase2_source_updated_at, Some(101)); - - let _ = 
tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn mark_global_phase2_job_succeeded_only_marks_exact_selected_snapshots() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_id, - codex_home.join("workspace"), - )) - .await - .expect("upsert thread"); - - let initial_claim = runtime - .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) - .await - .expect("claim initial stage1"); - let initial_token = match initial_claim { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - initial_token.as_str(), - 100, - "raw-100", - "summary-100", - Some("rollout-100"), - ) - .await - .expect("mark initial stage1 success"), - "initial stage1 success should persist output" - ); - - let phase2_claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim phase2"); - let (phase2_token, input_watermark) = match phase2_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected phase2 claim outcome: {other:?}"), - }; - let selected_outputs = runtime - .list_stage1_outputs_for_global(1) - .await - .expect("list selected outputs"); - assert_eq!(selected_outputs[0].source_updated_at.timestamp(), 100); - - let refreshed_claim = runtime - .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) - .await - .expect("claim refreshed stage1"); - let refreshed_token = match refreshed_claim { - Stage1JobClaimOutcome::Claimed { 
ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_id, - refreshed_token.as_str(), - 101, - "raw-101", - "summary-101", - Some("rollout-101"), - ) - .await - .expect("mark refreshed stage1 success"), - "refreshed stage1 success should persist output" - ); - - assert!( - runtime - .mark_global_phase2_job_succeeded( - phase2_token.as_str(), - input_watermark, - &selected_outputs, - ) - .await - .expect("mark phase2 success"), - "phase2 success should still complete" - ); - - let (selected_for_phase2, selected_for_phase2_source_updated_at) = - sqlx::query_as::<_, (i64, Option)>( - "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", - ) - .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load selected_for_phase2"); - assert_eq!(selected_for_phase2, 0); - assert_eq!(selected_for_phase2_source_updated_at, None); - - let selection = runtime - .get_phase2_input_selection(1) - .await - .expect("load phase2 input selection"); - assert_eq!(selection.selected.len(), 1); - assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101); - assert!(selection.retained_thread_ids.is_empty()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn record_stage1_output_usage_updates_usage_metadata() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); - let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); - let missing = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("missing id"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - runtime - 
.upsert_thread(&test_thread_metadata( - &codex_home, - thread_a, - codex_home.join("workspace-a"), - )) - .await - .expect("upsert thread a"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_b, - codex_home.join("workspace-b"), - )) - .await - .expect("upsert thread b"); - - let claim_a = runtime - .try_claim_stage1_job(thread_a, owner, 100, 3600, 64) - .await - .expect("claim stage1 a"); - let token_a = match claim_a { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome for a: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded(thread_a, token_a.as_str(), 100, "raw a", "sum a", None) - .await - .expect("mark stage1 succeeded a") - ); - - let claim_b = runtime - .try_claim_stage1_job(thread_b, owner, 101, 3600, 64) - .await - .expect("claim stage1 b"); - let token_b = match claim_b { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome for b: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded(thread_b, token_b.as_str(), 101, "raw b", "sum b", None) - .await - .expect("mark stage1 succeeded b") - ); - - let updated_rows = runtime - .record_stage1_output_usage(&[thread_a, thread_a, thread_b, missing]) - .await - .expect("record stage1 output usage"); - assert_eq!(updated_rows, 3); - - let row_a = - sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") - .bind(thread_a.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load stage1 usage row a"); - let row_b = - sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") - .bind(thread_b.to_string()) - .fetch_one(runtime.pool.as_ref()) - .await - .expect("load stage1 usage row b"); - - assert_eq!( - row_a - .try_get::("usage_count") - .expect("usage_count a"), - 2 - ); - assert_eq!( - row_b - .try_get::("usage_count") - .expect("usage_count b"), 
- 1 - ); - - let last_usage_a = row_a.try_get::("last_usage").expect("last_usage a"); - let last_usage_b = row_b.try_get::("last_usage").expect("last_usage b"); - assert_eq!(last_usage_a, last_usage_b); - assert!(last_usage_a > 0); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn mark_stage1_job_succeeded_enqueues_global_consolidation() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); - let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_a, - codex_home.join("workspace-a"), - )) - .await - .expect("upsert thread a"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - thread_b, - codex_home.join("workspace-b"), - )) - .await - .expect("upsert thread b"); - - let claim_a = runtime - .try_claim_stage1_job(thread_a, owner, 100, 3600, 64) - .await - .expect("claim stage1 a"); - let token_a = match claim_a { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome for thread a: {other:?}"), - }; - assert!( - runtime - .mark_stage1_job_succeeded( - thread_a, - token_a.as_str(), - 100, - "raw-a", - "summary-a", - None, - ) - .await - .expect("mark stage1 succeeded a"), - "stage1 success should persist output for thread a" - ); - - let claim_b = runtime - .try_claim_stage1_job(thread_b, owner, 101, 3600, 64) - .await - .expect("claim stage1 b"); - let token_b = match claim_b { - Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage1 claim outcome for thread b: {other:?}"), - }; - assert!( - 
runtime - .mark_stage1_job_succeeded( - thread_b, - token_b.as_str(), - 101, - "raw-b", - "summary-b", - None, - ) - .await - .expect("mark stage1 succeeded b"), - "stage1 success should persist output for thread b" - ); - - let claim = runtime - .try_claim_global_phase2_job(owner, 3600) - .await - .expect("claim global consolidation"); - let input_watermark = match claim { - Phase2JobClaimOutcome::Claimed { - input_watermark, .. - } => input_watermark, - other => panic!("unexpected global consolidation claim outcome: {other:?}"), - }; - assert_eq!(input_watermark, 101); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_global_lock_allows_only_one_fresh_runner() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - runtime - .enqueue_global_consolidation(200) - .await - .expect("enqueue global consolidation"); - - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner a"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner b"); - - let running_claim = runtime - .try_claim_global_phase2_job(owner_a, 3600) - .await - .expect("claim global lock"); - assert!( - matches!(running_claim, Phase2JobClaimOutcome::Claimed { .. 
}), - "first owner should claim global lock" - ); - - let second_claim = runtime - .try_claim_global_phase2_job(owner_b, 3600) - .await - .expect("claim global lock from second owner"); - assert_eq!(second_claim, Phase2JobClaimOutcome::SkippedRunning); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_global_lock_stale_lease_allows_takeover() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - runtime - .enqueue_global_consolidation(300) - .await - .expect("enqueue global consolidation"); - - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner a"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner b"); - - let initial_claim = runtime - .try_claim_global_phase2_job(owner_a, 3600) - .await - .expect("claim initial global lock"); - let token_a = match initial_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, .. - } => ownership_token, - other => panic!("unexpected initial claim outcome: {other:?}"), - }; - - sqlx::query("UPDATE jobs SET lease_until = ? WHERE kind = ? 
AND job_key = ?") - .bind(Utc::now().timestamp() - 1) - .bind("memory_consolidate_global") - .bind("global") - .execute(runtime.pool.as_ref()) - .await - .expect("expire global consolidation lease"); - - let takeover_claim = runtime - .try_claim_global_phase2_job(owner_b, 3600) - .await - .expect("claim stale global lock"); - let (token_b, input_watermark) = match takeover_claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => (ownership_token, input_watermark), - other => panic!("unexpected takeover claim outcome: {other:?}"), - }; - assert_ne!(token_a, token_b); - assert_eq!(input_watermark, 300); - - assert_eq!( - runtime - .mark_global_phase2_job_succeeded(token_a.as_str(), 300, &[]) - .await - .expect("mark stale owner success result"), - false, - "stale owner should lose finalization ownership after takeover" - ); - assert!( - runtime - .mark_global_phase2_job_succeeded(token_b.as_str(), 300, &[]) - .await - .expect("mark takeover owner success"), - "takeover owner should finalize consolidation" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_backfilled_inputs_below_last_success_still_become_dirty() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - runtime - .enqueue_global_consolidation(500) - .await - .expect("enqueue initial consolidation"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner a"); - let claim_a = runtime - .try_claim_global_phase2_job(owner_a, 3_600) - .await - .expect("claim initial consolidation"); - let token_a = match claim_a { - Phase2JobClaimOutcome::Claimed { - ownership_token, - input_watermark, - } => { - assert_eq!(input_watermark, 500); - ownership_token - } - other => panic!("unexpected initial phase2 claim outcome: {other:?}"), - }; - assert!( - runtime - 
.mark_global_phase2_job_succeeded(token_a.as_str(), 500, &[]) - .await - .expect("mark initial phase2 success"), - "initial phase2 success should finalize" - ); - - runtime - .enqueue_global_consolidation(400) - .await - .expect("enqueue backfilled consolidation"); - - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner b"); - let claim_b = runtime - .try_claim_global_phase2_job(owner_b, 3_600) - .await - .expect("claim backfilled consolidation"); - match claim_b { - Phase2JobClaimOutcome::Claimed { - input_watermark, .. - } => { - assert!( - input_watermark > 500, - "backfilled enqueue should advance dirty watermark beyond last success" - ); - } - other => panic!("unexpected backfilled phase2 claim outcome: {other:?}"), - } - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_failure_fallback_updates_unowned_running_job() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - runtime - .enqueue_global_consolidation(400) - .await - .expect("enqueue global consolidation"); - - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner"); - let claim = runtime - .try_claim_global_phase2_job(owner, 3_600) - .await - .expect("claim global consolidation"); - let ownership_token = match claim { - Phase2JobClaimOutcome::Claimed { - ownership_token, .. - } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - sqlx::query("UPDATE jobs SET ownership_token = NULL WHERE kind = ? 
AND job_key = ?") - .bind("memory_consolidate_global") - .bind("global") - .execute(runtime.pool.as_ref()) - .await - .expect("clear ownership token"); - - assert_eq!( - runtime - .mark_global_phase2_job_failed(ownership_token.as_str(), "lost", 3_600) - .await - .expect("mark phase2 failed with strict ownership"), - false, - "strict failure update should not match unowned running job" - ); - assert!( - runtime - .mark_global_phase2_job_failed_if_unowned(ownership_token.as_str(), "lost", 3_600) - .await - .expect("fallback failure update should match unowned running job"), - "fallback failure update should transition the unowned running job" - ); - - let claim = runtime - .try_claim_global_phase2_job(ThreadId::new(), 3_600) - .await - .expect("claim after fallback failure"); - assert_eq!(claim, Phase2JobClaimOutcome::SkippedNotDirty); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn query_logs_with_search_matches_substring() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - runtime - .insert_logs(&[ - LogEntry { - ts: 1_700_000_001, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some("alpha".to_string()), - thread_id: Some("thread-1".to_string()), - process_uuid: None, - file: Some("main.rs".to_string()), - line: Some(42), - module_path: None, - }, - LogEntry { - ts: 1_700_000_002, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some("alphabet".to_string()), - thread_id: Some("thread-1".to_string()), - process_uuid: None, - file: Some("main.rs".to_string()), - line: Some(43), - module_path: None, - }, - ]) - .await - .expect("insert test logs"); - - let rows = runtime - .query_logs(&LogQuery { - search: Some("alphab".to_string()), - ..Default::default() - }) - .await - .expect("query matching logs"); - - assert_eq!(rows.len(), 
1); - assert_eq!(rows[0].message.as_deref(), Some("alphabet")); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn insert_logs_prunes_old_rows_when_thread_exceeds_size_limit() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let six_mebibytes = "a".repeat(6 * 1024 * 1024); - runtime - .insert_logs(&[ - LogEntry { - ts: 1, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes.clone()), - thread_id: Some("thread-1".to_string()), - process_uuid: Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(1), - module_path: Some("mod".to_string()), - }, - LogEntry { - ts: 2, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes.clone()), - thread_id: Some("thread-1".to_string()), - process_uuid: Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(2), - module_path: Some("mod".to_string()), - }, - ]) - .await - .expect("insert test logs"); - - let rows = runtime - .query_logs(&LogQuery { - thread_ids: vec!["thread-1".to_string()], - ..Default::default() - }) - .await - .expect("query thread logs"); - - assert_eq!(rows.len(), 1); - assert_eq!(rows[0].ts, 2); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn insert_logs_prunes_single_thread_row_when_it_exceeds_size_limit() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let eleven_mebibytes = "d".repeat(11 * 1024 * 1024); - runtime - .insert_logs(&[LogEntry { - ts: 1, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(eleven_mebibytes), - thread_id: Some("thread-oversized".to_string()), - process_uuid: 
Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(1), - module_path: Some("mod".to_string()), - }]) - .await - .expect("insert test log"); - - let rows = runtime - .query_logs(&LogQuery { - thread_ids: vec!["thread-oversized".to_string()], - ..Default::default() - }) - .await - .expect("query thread logs"); - - assert!(rows.is_empty()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn insert_logs_prunes_threadless_rows_per_process_uuid_only() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let six_mebibytes = "b".repeat(6 * 1024 * 1024); - runtime - .insert_logs(&[ - LogEntry { - ts: 1, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes.clone()), - thread_id: None, - process_uuid: Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(1), - module_path: Some("mod".to_string()), - }, - LogEntry { - ts: 2, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes.clone()), - thread_id: None, - process_uuid: Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(2), - module_path: Some("mod".to_string()), - }, - LogEntry { - ts: 3, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes), - thread_id: Some("thread-1".to_string()), - process_uuid: Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(3), - module_path: Some("mod".to_string()), - }, - ]) - .await - .expect("insert test logs"); - - let rows = runtime - .query_logs(&LogQuery { - thread_ids: vec!["thread-1".to_string()], - include_threadless: true, - ..Default::default() - }) - .await - .expect("query thread and threadless logs"); - - let mut timestamps: Vec = rows.into_iter().map(|row| 
row.ts).collect(); - timestamps.sort_unstable(); - assert_eq!(timestamps, vec![2, 3]); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn insert_logs_prunes_single_threadless_process_row_when_it_exceeds_size_limit() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let eleven_mebibytes = "e".repeat(11 * 1024 * 1024); - runtime - .insert_logs(&[LogEntry { - ts: 1, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(eleven_mebibytes), - thread_id: None, - process_uuid: Some("proc-oversized".to_string()), - file: Some("main.rs".to_string()), - line: Some(1), - module_path: Some("mod".to_string()), - }]) - .await - .expect("insert test log"); - - let rows = runtime - .query_logs(&LogQuery { - include_threadless: true, - ..Default::default() - }) - .await - .expect("query threadless logs"); - - assert!(rows.is_empty()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn insert_logs_prunes_threadless_rows_with_null_process_uuid() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let six_mebibytes = "c".repeat(6 * 1024 * 1024); - runtime - .insert_logs(&[ - LogEntry { - ts: 1, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes.clone()), - thread_id: None, - process_uuid: None, - file: Some("main.rs".to_string()), - line: Some(1), - module_path: Some("mod".to_string()), - }, - LogEntry { - ts: 2, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(six_mebibytes), - thread_id: None, - process_uuid: None, - file: Some("main.rs".to_string()), - line: Some(2), - module_path: Some("mod".to_string()), - }, - LogEntry { - ts: 3, - 
ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some("small".to_string()), - thread_id: None, - process_uuid: Some("proc-1".to_string()), - file: Some("main.rs".to_string()), - line: Some(3), - module_path: Some("mod".to_string()), - }, - ]) - .await - .expect("insert test logs"); - - let rows = runtime - .query_logs(&LogQuery { - include_threadless: true, - ..Default::default() - }) - .await - .expect("query threadless logs"); - - let mut timestamps: Vec = rows.into_iter().map(|row| row.ts).collect(); - timestamps.sort_unstable(); - assert_eq!(timestamps, vec![2, 3]); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn insert_logs_prunes_single_threadless_null_process_row_when_it_exceeds_limit() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let eleven_mebibytes = "f".repeat(11 * 1024 * 1024); - runtime - .insert_logs(&[LogEntry { - ts: 1, - ts_nanos: 0, - level: "INFO".to_string(), - target: "cli".to_string(), - message: Some(eleven_mebibytes), - thread_id: None, - process_uuid: None, - file: Some("main.rs".to_string()), - line: Some(1), - module_path: Some("mod".to_string()), - }]) - .await - .expect("insert test log"); - - let rows = runtime - .query_logs(&LogQuery { - include_threadless: true, - ..Default::default() - }) - .await - .expect("query threadless logs"); - - assert!(rows.is_empty()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - fn test_thread_metadata( - codex_home: &Path, - thread_id: ThreadId, - cwd: PathBuf, - ) -> ThreadMetadata { - let now = DateTime::::from_timestamp(1_700_000_000, 0).expect("timestamp"); - ThreadMetadata { - id: thread_id, - rollout_path: codex_home.join(format!("rollout-{thread_id}.jsonl")), - created_at: now, - updated_at: now, - source: "cli".to_string(), - agent_nickname: None, - agent_role: None, - 
model_provider: "test-provider".to_string(), - cwd, - cli_version: "0.0.0".to_string(), - title: String::new(), - sandbox_policy: crate::extract::enum_to_string(&SandboxPolicy::new_read_only_policy()), - approval_mode: crate::extract::enum_to_string(&AskForApproval::OnRequest), - tokens_used: 0, - first_user_message: Some("hello".to_string()), - archived_at: None, - git_sha: None, - git_branch: None, - git_origin_url: None, - } - } -} diff --git a/codex-rs/state/src/runtime/agent_jobs.rs b/codex-rs/state/src/runtime/agent_jobs.rs new file mode 100644 index 00000000000..c6856059457 --- /dev/null +++ b/codex-rs/state/src/runtime/agent_jobs.rs @@ -0,0 +1,562 @@ +use super::*; +use crate::model::AgentJobItemRow; + +impl StateRuntime { + pub async fn create_agent_job( + &self, + params: &AgentJobCreateParams, + items: &[AgentJobItemCreateParams], + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let input_headers_json = serde_json::to_string(¶ms.input_headers)?; + let output_schema_json = params + .output_schema_json + .as_ref() + .map(serde_json::to_string) + .transpose()?; + let max_runtime_seconds = params + .max_runtime_seconds + .map(i64::try_from) + .transpose() + .map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?; + let mut tx = self.pool.begin().await?; + sqlx::query( + r#" +INSERT INTO agent_jobs ( + id, + name, + status, + instruction, + auto_export, + max_runtime_seconds, + output_schema_json, + input_headers_json, + input_csv_path, + output_csv_path, + created_at, + updated_at, + started_at, + completed_at, + last_error +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL) + "#, + ) + .bind(params.id.as_str()) + .bind(params.name.as_str()) + .bind(AgentJobStatus::Pending.as_str()) + .bind(params.instruction.as_str()) + .bind(i64::from(params.auto_export)) + .bind(max_runtime_seconds) + .bind(output_schema_json) + .bind(input_headers_json) + .bind(params.input_csv_path.as_str()) + .bind(params.output_csv_path.as_str()) 
+ .bind(now) + .bind(now) + .execute(&mut *tx) + .await?; + + for item in items { + let row_json = serde_json::to_string(&item.row_json)?; + sqlx::query( + r#" +INSERT INTO agent_job_items ( + job_id, + item_id, + row_index, + source_id, + row_json, + status, + assigned_thread_id, + attempt_count, + result_json, + last_error, + created_at, + updated_at, + completed_at, + reported_at +) VALUES (?, ?, ?, ?, ?, ?, NULL, 0, NULL, NULL, ?, ?, NULL, NULL) + "#, + ) + .bind(params.id.as_str()) + .bind(item.item_id.as_str()) + .bind(item.row_index) + .bind(item.source_id.as_deref()) + .bind(row_json) + .bind(AgentJobItemStatus::Pending.as_str()) + .bind(now) + .bind(now) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + let job_id = params.id.as_str(); + self.get_agent_job(job_id) + .await? + .ok_or_else(|| anyhow::anyhow!("failed to load created agent job {job_id}")) + } + + pub async fn get_agent_job(&self, job_id: &str) -> anyhow::Result> { + let row = sqlx::query_as::<_, AgentJobRow>( + r#" +SELECT + id, + name, + status, + instruction, + auto_export, + max_runtime_seconds, + output_schema_json, + input_headers_json, + input_csv_path, + output_csv_path, + created_at, + updated_at, + started_at, + completed_at, + last_error +FROM agent_jobs +WHERE id = ? 
+ "#, + ) + .bind(job_id) + .fetch_optional(self.pool.as_ref()) + .await?; + row.map(AgentJob::try_from).transpose() + } + + pub async fn list_agent_job_items( + &self, + job_id: &str, + status: Option, + limit: Option, + ) -> anyhow::Result> { + let mut builder = QueryBuilder::::new( + r#" +SELECT + job_id, + item_id, + row_index, + source_id, + row_json, + status, + assigned_thread_id, + attempt_count, + result_json, + last_error, + created_at, + updated_at, + completed_at, + reported_at +FROM agent_job_items +WHERE job_id = + "#, + ); + builder.push_bind(job_id); + if let Some(status) = status { + builder.push(" AND status = "); + builder.push_bind(status.as_str()); + } + builder.push(" ORDER BY row_index ASC"); + if let Some(limit) = limit { + builder.push(" LIMIT "); + builder.push_bind(limit as i64); + } + let rows: Vec = builder + .build_query_as::() + .fetch_all(self.pool.as_ref()) + .await?; + rows.into_iter().map(AgentJobItem::try_from).collect() + } + + pub async fn get_agent_job_item( + &self, + job_id: &str, + item_id: &str, + ) -> anyhow::Result> { + let row: Option = sqlx::query_as::<_, AgentJobItemRow>( + r#" +SELECT + job_id, + item_id, + row_index, + source_id, + row_json, + status, + assigned_thread_id, + attempt_count, + result_json, + last_error, + created_at, + updated_at, + completed_at, + reported_at +FROM agent_job_items +WHERE job_id = ? AND item_id = ? + "#, + ) + .bind(job_id) + .bind(item_id) + .fetch_optional(self.pool.as_ref()) + .await?; + row.map(AgentJobItem::try_from).transpose() + } + + pub async fn mark_agent_job_running(&self, job_id: &str) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE agent_jobs +SET + status = ?, + updated_at = ?, + started_at = COALESCE(started_at, ?), + completed_at = NULL, + last_error = NULL +WHERE id = ? 
+ "#, + ) + .bind(AgentJobStatus::Running.as_str()) + .bind(now) + .bind(now) + .bind(job_id) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = NULL +WHERE id = ? + "#, + ) + .bind(AgentJobStatus::Completed.as_str()) + .bind(now) + .bind(now) + .bind(job_id) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn mark_agent_job_failed( + &self, + job_id: &str, + error_message: &str, + ) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = ? +WHERE id = ? + "#, + ) + .bind(AgentJobStatus::Failed.as_str()) + .bind(now) + .bind(now) + .bind(error_message) + .bind(job_id) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn mark_agent_job_cancelled( + &self, + job_id: &str, + reason: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = ? +WHERE id = ? AND status IN (?, ?) + "#, + ) + .bind(AgentJobStatus::Cancelled.as_str()) + .bind(now) + .bind(now) + .bind(reason) + .bind(job_id) + .bind(AgentJobStatus::Pending.as_str()) + .bind(AgentJobStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn is_agent_job_cancelled(&self, job_id: &str) -> anyhow::Result { + let row = sqlx::query( + r#" +SELECT status +FROM agent_jobs +WHERE id = ? + "#, + ) + .bind(job_id) + .fetch_optional(self.pool.as_ref()) + .await?; + let Some(row) = row else { + return Ok(false); + }; + let status: String = row.try_get("status")?; + Ok(AgentJobStatus::parse(status.as_str())? 
== AgentJobStatus::Cancelled) + } + + pub async fn mark_agent_job_item_running( + &self, + job_id: &str, + item_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + assigned_thread_id = NULL, + attempt_count = attempt_count + 1, + updated_at = ?, + last_error = NULL +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Pending.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_running_with_thread( + &self, + job_id: &str, + item_id: &str, + thread_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + assigned_thread_id = ?, + attempt_count = attempt_count + 1, + updated_at = ?, + last_error = NULL +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(thread_id) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Pending.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_pending( + &self, + job_id: &str, + item_id: &str, + error_message: Option<&str>, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + assigned_thread_id = NULL, + updated_at = ?, + last_error = ? +WHERE job_id = ? AND item_id = ? AND status = ? 
+ "#, + ) + .bind(AgentJobItemStatus::Pending.as_str()) + .bind(now) + .bind(error_message) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn set_agent_job_item_thread( + &self, + job_id: &str, + item_id: &str, + thread_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET assigned_thread_id = ?, updated_at = ? +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(thread_id) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn report_agent_job_item_result( + &self, + job_id: &str, + item_id: &str, + reporting_thread_id: &str, + result_json: &Value, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let serialized = serde_json::to_string(result_json)?; + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + result_json = ?, + reported_at = ?, + updated_at = ?, + last_error = NULL +WHERE + job_id = ? + AND item_id = ? + AND status = ? + AND assigned_thread_id = ? + "#, + ) + .bind(serialized) + .bind(now) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(reporting_thread_id) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_completed( + &self, + job_id: &str, + item_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + completed_at = ?, + updated_at = ?, + assigned_thread_id = NULL +WHERE + job_id = ? + AND item_id = ? + AND status = ? 
+ AND result_json IS NOT NULL + "#, + ) + .bind(AgentJobItemStatus::Completed.as_str()) + .bind(now) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_failed( + &self, + job_id: &str, + item_id: &str, + error_message: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + completed_at = ?, + updated_at = ?, + last_error = ?, + assigned_thread_id = NULL +WHERE + job_id = ? + AND item_id = ? + AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Failed.as_str()) + .bind(now) + .bind(now) + .bind(error_message) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn get_agent_job_progress(&self, job_id: &str) -> anyhow::Result { + let row = sqlx::query( + r#" +SELECT + COUNT(*) AS total_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS pending_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS running_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS completed_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS failed_items +FROM agent_job_items +WHERE job_id = ? 
+ "#, + ) + .bind(AgentJobItemStatus::Pending.as_str()) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(AgentJobItemStatus::Completed.as_str()) + .bind(AgentJobItemStatus::Failed.as_str()) + .bind(job_id) + .fetch_one(self.pool.as_ref()) + .await?; + + let total_items: i64 = row.try_get("total_items")?; + let pending_items: Option = row.try_get("pending_items")?; + let running_items: Option = row.try_get("running_items")?; + let completed_items: Option = row.try_get("completed_items")?; + let failed_items: Option = row.try_get("failed_items")?; + Ok(AgentJobProgress { + total_items: usize::try_from(total_items).unwrap_or_default(), + pending_items: usize::try_from(pending_items.unwrap_or_default()).unwrap_or_default(), + running_items: usize::try_from(running_items.unwrap_or_default()).unwrap_or_default(), + completed_items: usize::try_from(completed_items.unwrap_or_default()) + .unwrap_or_default(), + failed_items: usize::try_from(failed_items.unwrap_or_default()).unwrap_or_default(), + }) + } +} diff --git a/codex-rs/state/src/runtime/backfill.rs b/codex-rs/state/src/runtime/backfill.rs new file mode 100644 index 00000000000..93e15698724 --- /dev/null +++ b/codex-rs/state/src/runtime/backfill.rs @@ -0,0 +1,311 @@ +use super::*; + +impl StateRuntime { + pub async fn get_backfill_state(&self) -> anyhow::Result { + self.ensure_backfill_state_row().await?; + let row = sqlx::query( + r#" +SELECT status, last_watermark, last_success_at +FROM backfill_state +WHERE id = 1 + "#, + ) + .fetch_one(self.pool.as_ref()) + .await?; + crate::BackfillState::try_from_row(&row) + } + + /// Attempt to claim ownership of rollout metadata backfill. + /// + /// Returns `true` when this runtime claimed the backfill worker slot. + /// Returns `false` if backfill is already complete or currently owned by a + /// non-expired worker. 
+ pub async fn try_claim_backfill(&self, lease_seconds: i64) -> anyhow::Result { + self.ensure_backfill_state_row().await?; + let now = Utc::now().timestamp(); + let lease_cutoff = now.saturating_sub(lease_seconds.max(0)); + let result = sqlx::query( + r#" +UPDATE backfill_state +SET status = ?, updated_at = ? +WHERE id = 1 + AND status != ? + AND (status != ? OR updated_at <= ?) + "#, + ) + .bind(crate::BackfillStatus::Running.as_str()) + .bind(now) + .bind(crate::BackfillStatus::Complete.as_str()) + .bind(crate::BackfillStatus::Running.as_str()) + .bind(lease_cutoff) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() == 1) + } + + /// Mark rollout metadata backfill as running. + pub async fn mark_backfill_running(&self) -> anyhow::Result<()> { + self.ensure_backfill_state_row().await?; + sqlx::query( + r#" +UPDATE backfill_state +SET status = ?, updated_at = ? +WHERE id = 1 + "#, + ) + .bind(crate::BackfillStatus::Running.as_str()) + .bind(Utc::now().timestamp()) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + /// Persist rollout metadata backfill progress. + pub async fn checkpoint_backfill(&self, watermark: &str) -> anyhow::Result<()> { + self.ensure_backfill_state_row().await?; + sqlx::query( + r#" +UPDATE backfill_state +SET status = ?, last_watermark = ?, updated_at = ? +WHERE id = 1 + "#, + ) + .bind(crate::BackfillStatus::Running.as_str()) + .bind(watermark) + .bind(Utc::now().timestamp()) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + /// Mark rollout metadata backfill as complete. + pub async fn mark_backfill_complete(&self, last_watermark: Option<&str>) -> anyhow::Result<()> { + self.ensure_backfill_state_row().await?; + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE backfill_state +SET + status = ?, + last_watermark = COALESCE(?, last_watermark), + last_success_at = ?, + updated_at = ? 
+WHERE id = 1 + "#, + ) + .bind(crate::BackfillStatus::Complete.as_str()) + .bind(last_watermark) + .bind(now) + .bind(now) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> { + sqlx::query( + r#" +INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at) +VALUES (?, ?, NULL, NULL, ?) +ON CONFLICT(id) DO NOTHING + "#, + ) + .bind(1_i64) + .bind(crate::BackfillStatus::Pending.as_str()) + .bind(Utc::now().timestamp()) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::StateRuntime; + use super::state_db_filename; + use super::test_support::unique_temp_dir; + use crate::STATE_DB_FILENAME; + use crate::STATE_DB_VERSION; + use chrono::Utc; + use pretty_assertions::assert_eq; + + #[tokio::test] + async fn init_removes_legacy_state_db_files() { + let codex_home = unique_temp_dir(); + tokio::fs::create_dir_all(&codex_home) + .await + .expect("create codex_home"); + + let current_name = state_db_filename(); + let previous_version = STATE_DB_VERSION.saturating_sub(1); + let unversioned_name = format!("{STATE_DB_FILENAME}.sqlite"); + for suffix in ["", "-wal", "-shm", "-journal"] { + let path = codex_home.join(format!("{unversioned_name}{suffix}")); + tokio::fs::write(path, b"legacy") + .await + .expect("write legacy"); + let old_version_path = codex_home.join(format!( + "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}" + )); + tokio::fs::write(old_version_path, b"old_version") + .await + .expect("write old version"); + } + let unrelated_path = codex_home.join("state.sqlite_backup"); + tokio::fs::write(&unrelated_path, b"keep") + .await + .expect("write unrelated"); + let numeric_path = codex_home.join("123"); + tokio::fs::write(&numeric_path, b"keep") + .await + .expect("write numeric"); + + let _runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + 
+ for suffix in ["", "-wal", "-shm", "-journal"] { + let legacy_path = codex_home.join(format!("{unversioned_name}{suffix}")); + assert_eq!( + tokio::fs::try_exists(&legacy_path) + .await + .expect("check legacy path"), + false + ); + let old_version_path = codex_home.join(format!( + "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}" + )); + assert_eq!( + tokio::fs::try_exists(&old_version_path) + .await + .expect("check old version path"), + false + ); + } + assert_eq!( + tokio::fs::try_exists(codex_home.join(current_name)) + .await + .expect("check new db path"), + true + ); + assert_eq!( + tokio::fs::try_exists(&unrelated_path) + .await + .expect("check unrelated path"), + true + ); + assert_eq!( + tokio::fs::try_exists(&numeric_path) + .await + .expect("check numeric path"), + true + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn backfill_state_persists_progress_and_completion() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let initial = runtime + .get_backfill_state() + .await + .expect("get initial backfill state"); + assert_eq!(initial.status, crate::BackfillStatus::Pending); + assert_eq!(initial.last_watermark, None); + assert_eq!(initial.last_success_at, None); + + runtime + .mark_backfill_running() + .await + .expect("mark backfill running"); + runtime + .checkpoint_backfill("sessions/2026/01/27/rollout-a.jsonl") + .await + .expect("checkpoint backfill"); + + let running = runtime + .get_backfill_state() + .await + .expect("get running backfill state"); + assert_eq!(running.status, crate::BackfillStatus::Running); + assert_eq!( + running.last_watermark, + Some("sessions/2026/01/27/rollout-a.jsonl".to_string()) + ); + assert_eq!(running.last_success_at, None); + + runtime + .mark_backfill_complete(Some("sessions/2026/01/28/rollout-b.jsonl")) + .await + .expect("mark backfill 
complete"); + let completed = runtime + .get_backfill_state() + .await + .expect("get completed backfill state"); + assert_eq!(completed.status, crate::BackfillStatus::Complete); + assert_eq!( + completed.last_watermark, + Some("sessions/2026/01/28/rollout-b.jsonl".to_string()) + ); + assert!(completed.last_success_at.is_some()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn backfill_claim_is_singleton_until_stale_and_blocked_when_complete() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let claimed = runtime + .try_claim_backfill(3600) + .await + .expect("initial backfill claim"); + assert_eq!(claimed, true); + + let duplicate_claim = runtime + .try_claim_backfill(3600) + .await + .expect("duplicate backfill claim"); + assert_eq!(duplicate_claim, false); + + let stale_updated_at = Utc::now().timestamp().saturating_sub(10_000); + sqlx::query( + r#" +UPDATE backfill_state +SET status = ?, updated_at = ? 
+WHERE id = 1 + "#, + ) + .bind(crate::BackfillStatus::Running.as_str()) + .bind(stale_updated_at) + .execute(runtime.pool.as_ref()) + .await + .expect("force stale backfill lease"); + + let stale_claim = runtime + .try_claim_backfill(10) + .await + .expect("stale backfill claim"); + assert_eq!(stale_claim, true); + + runtime + .mark_backfill_complete(None) + .await + .expect("mark complete"); + let claim_after_complete = runtime + .try_claim_backfill(3600) + .await + .expect("claim after complete"); + assert_eq!(claim_after_complete, false); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } +} diff --git a/codex-rs/state/src/runtime/logs.rs b/codex-rs/state/src/runtime/logs.rs new file mode 100644 index 00000000000..762f68f1a9a --- /dev/null +++ b/codex-rs/state/src/runtime/logs.rs @@ -0,0 +1,715 @@ +use super::*; + +impl StateRuntime { + pub async fn insert_log(&self, entry: &LogEntry) -> anyhow::Result<()> { + self.insert_logs(std::slice::from_ref(entry)).await + } + + /// Insert a batch of log entries into the logs table. 
+ pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> { + if entries.is_empty() { + return Ok(()); + } + + let mut tx = self.pool.begin().await?; + let mut builder = QueryBuilder::::new( + "INSERT INTO logs (ts, ts_nanos, level, target, message, thread_id, process_uuid, module_path, file, line, estimated_bytes) ", + ); + builder.push_values(entries, |mut row, entry| { + let estimated_bytes = entry.message.as_ref().map_or(0, String::len) as i64 + + entry.level.len() as i64 + + entry.target.len() as i64 + + entry.module_path.as_ref().map_or(0, String::len) as i64 + + entry.file.as_ref().map_or(0, String::len) as i64; + row.push_bind(entry.ts) + .push_bind(entry.ts_nanos) + .push_bind(&entry.level) + .push_bind(&entry.target) + .push_bind(&entry.message) + .push_bind(&entry.thread_id) + .push_bind(&entry.process_uuid) + .push_bind(&entry.module_path) + .push_bind(&entry.file) + .push_bind(entry.line) + .push_bind(estimated_bytes); + }); + builder.build().execute(&mut *tx).await?; + self.prune_logs_after_insert(entries, &mut tx).await?; + tx.commit().await?; + Ok(()) + } + + /// Enforce per-partition log size caps after a successful batch insert. + /// + /// We maintain two independent budgets: + /// - Thread logs: rows with `thread_id IS NOT NULL`, capped per `thread_id`. + /// - Threadless process logs: rows with `thread_id IS NULL` ("threadless"), + /// capped per `process_uuid` (including `process_uuid IS NULL` as its own + /// threadless partition). + /// + /// "Threadless" means the log row is not associated with any conversation + /// thread, so retention is keyed by process identity instead. + /// + /// This runs inside the same transaction as the insert so callers never + /// observe "inserted but not yet pruned" rows. 
+ async fn prune_logs_after_insert( + &self, + entries: &[LogEntry], + tx: &mut SqliteConnection, + ) -> anyhow::Result<()> { + let thread_ids: BTreeSet<&str> = entries + .iter() + .filter_map(|entry| entry.thread_id.as_deref()) + .collect(); + if !thread_ids.is_empty() { + // Cheap precheck: only run the heavier window-function prune for + // threads that are currently above the cap. + let mut over_limit_threads_query = + QueryBuilder::::new("SELECT thread_id FROM logs WHERE thread_id IN ("); + { + let mut separated = over_limit_threads_query.separated(", "); + for thread_id in &thread_ids { + separated.push_bind(*thread_id); + } + } + over_limit_threads_query.push(") GROUP BY thread_id HAVING SUM("); + over_limit_threads_query.push("estimated_bytes"); + over_limit_threads_query.push(") > "); + over_limit_threads_query.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); + let over_limit_thread_ids: Vec = over_limit_threads_query + .build() + .fetch_all(&mut *tx) + .await? + .into_iter() + .map(|row| row.try_get("thread_id")) + .collect::>()?; + if !over_limit_thread_ids.is_empty() { + // Enforce a strict per-thread cap by deleting every row whose + // newest-first cumulative bytes exceed the partition budget. 
+ let mut prune_threads = QueryBuilder::::new( + r#" +DELETE FROM logs +WHERE id IN ( + SELECT id + FROM ( + SELECT + id, + SUM( +"#, + ); + prune_threads.push("estimated_bytes"); + prune_threads.push( + r#" + ) OVER ( + PARTITION BY thread_id + ORDER BY ts DESC, ts_nanos DESC, id DESC + ) AS cumulative_bytes + FROM logs + WHERE thread_id IN ( +"#, + ); + { + let mut separated = prune_threads.separated(", "); + for thread_id in &over_limit_thread_ids { + separated.push_bind(thread_id); + } + } + prune_threads.push( + r#" + ) + ) + WHERE cumulative_bytes > +"#, + ); + prune_threads.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); + prune_threads.push("\n)"); + prune_threads.build().execute(&mut *tx).await?; + } + } + + let threadless_process_uuids: BTreeSet<&str> = entries + .iter() + .filter(|entry| entry.thread_id.is_none()) + .filter_map(|entry| entry.process_uuid.as_deref()) + .collect(); + let has_threadless_null_process_uuid = entries + .iter() + .any(|entry| entry.thread_id.is_none() && entry.process_uuid.is_none()); + if !threadless_process_uuids.is_empty() { + // Threadless logs are budgeted separately per process UUID. + let mut over_limit_processes_query = QueryBuilder::::new( + "SELECT process_uuid FROM logs WHERE thread_id IS NULL AND process_uuid IN (", + ); + { + let mut separated = over_limit_processes_query.separated(", "); + for process_uuid in &threadless_process_uuids { + separated.push_bind(*process_uuid); + } + } + over_limit_processes_query.push(") GROUP BY process_uuid HAVING SUM("); + over_limit_processes_query.push("estimated_bytes"); + over_limit_processes_query.push(") > "); + over_limit_processes_query.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); + let over_limit_process_uuids: Vec = over_limit_processes_query + .build() + .fetch_all(&mut *tx) + .await? 
+ .into_iter() + .map(|row| row.try_get("process_uuid")) + .collect::>()?; + if !over_limit_process_uuids.is_empty() { + // Same strict cap policy as thread pruning, but only for + // threadless rows in the affected process UUIDs. + let mut prune_threadless_process_logs = QueryBuilder::::new( + r#" +DELETE FROM logs +WHERE id IN ( + SELECT id + FROM ( + SELECT + id, + SUM( +"#, + ); + prune_threadless_process_logs.push("estimated_bytes"); + prune_threadless_process_logs.push( + r#" + ) OVER ( + PARTITION BY process_uuid + ORDER BY ts DESC, ts_nanos DESC, id DESC + ) AS cumulative_bytes + FROM logs + WHERE thread_id IS NULL + AND process_uuid IN ( +"#, + ); + { + let mut separated = prune_threadless_process_logs.separated(", "); + for process_uuid in &over_limit_process_uuids { + separated.push_bind(process_uuid); + } + } + prune_threadless_process_logs.push( + r#" + ) + ) + WHERE cumulative_bytes > +"#, + ); + prune_threadless_process_logs.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); + prune_threadless_process_logs.push("\n)"); + prune_threadless_process_logs + .build() + .execute(&mut *tx) + .await?; + } + } + if has_threadless_null_process_uuid { + // Rows without a process UUID still need a cap; treat NULL as its + // own threadless partition. + let mut null_process_usage_query = QueryBuilder::::new("SELECT SUM("); + null_process_usage_query.push("estimated_bytes"); + null_process_usage_query.push( + ") AS total_bytes FROM logs WHERE thread_id IS NULL AND process_uuid IS NULL", + ); + let total_null_process_bytes: Option = null_process_usage_query + .build() + .fetch_one(&mut *tx) + .await? 
+ .try_get("total_bytes")?; + + if total_null_process_bytes.unwrap_or(0) > LOG_PARTITION_SIZE_LIMIT_BYTES { + let mut prune_threadless_null_process_logs = QueryBuilder::::new( + r#" +DELETE FROM logs +WHERE id IN ( + SELECT id + FROM ( + SELECT + id, + SUM( +"#, + ); + prune_threadless_null_process_logs.push("estimated_bytes"); + prune_threadless_null_process_logs.push( + r#" + ) OVER ( + PARTITION BY process_uuid + ORDER BY ts DESC, ts_nanos DESC, id DESC + ) AS cumulative_bytes + FROM logs + WHERE thread_id IS NULL + AND process_uuid IS NULL + ) + WHERE cumulative_bytes > +"#, + ); + prune_threadless_null_process_logs.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES); + prune_threadless_null_process_logs.push("\n)"); + prune_threadless_null_process_logs + .build() + .execute(&mut *tx) + .await?; + } + } + Ok(()) + } + + pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result { + let result = sqlx::query("DELETE FROM logs WHERE ts < ?") + .bind(cutoff_ts) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected()) + } + + /// Query logs with optional filters. + pub async fn query_logs(&self, query: &LogQuery) -> anyhow::Result> { + let mut builder = QueryBuilder::::new( + "SELECT id, ts, ts_nanos, level, target, message, thread_id, process_uuid, file, line FROM logs WHERE 1 = 1", + ); + push_log_filters(&mut builder, query); + if query.descending { + builder.push(" ORDER BY id DESC"); + } else { + builder.push(" ORDER BY id ASC"); + } + if let Some(limit) = query.limit { + builder.push(" LIMIT ").push_bind(limit as i64); + } + + let rows = builder + .build_query_as::() + .fetch_all(self.pool.as_ref()) + .await?; + Ok(rows) + } + + /// Return the max log id matching optional filters. 
+ pub async fn max_log_id(&self, query: &LogQuery) -> anyhow::Result { + let mut builder = + QueryBuilder::::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1"); + push_log_filters(&mut builder, query); + let row = builder.build().fetch_one(self.pool.as_ref()).await?; + let max_id: Option = row.try_get("max_id")?; + Ok(max_id.unwrap_or(0)) + } +} + +fn push_log_filters<'a>(builder: &mut QueryBuilder<'a, Sqlite>, query: &'a LogQuery) { + if let Some(level_upper) = query.level_upper.as_ref() { + builder + .push(" AND UPPER(level) = ") + .push_bind(level_upper.as_str()); + } + if let Some(from_ts) = query.from_ts { + builder.push(" AND ts >= ").push_bind(from_ts); + } + if let Some(to_ts) = query.to_ts { + builder.push(" AND ts <= ").push_bind(to_ts); + } + push_like_filters(builder, "module_path", &query.module_like); + push_like_filters(builder, "file", &query.file_like); + let has_thread_filter = !query.thread_ids.is_empty() || query.include_threadless; + if has_thread_filter { + builder.push(" AND ("); + let mut needs_or = false; + for thread_id in &query.thread_ids { + if needs_or { + builder.push(" OR "); + } + builder.push("thread_id = ").push_bind(thread_id.as_str()); + needs_or = true; + } + if query.include_threadless { + if needs_or { + builder.push(" OR "); + } + builder.push("thread_id IS NULL"); + } + builder.push(")"); + } + if let Some(after_id) = query.after_id { + builder.push(" AND id > ").push_bind(after_id); + } + if let Some(search) = query.search.as_ref() { + builder.push(" AND INSTR(message, "); + builder.push_bind(search.as_str()); + builder.push(") > 0"); + } +} + +fn push_like_filters<'a>( + builder: &mut QueryBuilder<'a, Sqlite>, + column: &str, + filters: &'a [String], +) { + if filters.is_empty() { + return; + } + builder.push(" AND ("); + for (idx, filter) in filters.iter().enumerate() { + if idx > 0 { + builder.push(" OR "); + } + builder + .push(column) + .push(" LIKE '%' || ") + .push_bind(filter.as_str()) + .push(" || '%'"); + } + 
builder.push(")"); +} + +#[cfg(test)] +mod tests { + use super::StateRuntime; + use super::test_support::unique_temp_dir; + use crate::LogEntry; + use crate::LogQuery; + use pretty_assertions::assert_eq; + #[tokio::test] + async fn query_logs_with_search_matches_substring() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + runtime + .insert_logs(&[ + LogEntry { + ts: 1_700_000_001, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some("alpha".to_string()), + thread_id: Some("thread-1".to_string()), + process_uuid: None, + file: Some("main.rs".to_string()), + line: Some(42), + module_path: None, + }, + LogEntry { + ts: 1_700_000_002, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some("alphabet".to_string()), + thread_id: Some("thread-1".to_string()), + process_uuid: None, + file: Some("main.rs".to_string()), + line: Some(43), + module_path: None, + }, + ]) + .await + .expect("insert test logs"); + + let rows = runtime + .query_logs(&LogQuery { + search: Some("alphab".to_string()), + ..Default::default() + }) + .await + .expect("query matching logs"); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].message.as_deref(), Some("alphabet")); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn insert_logs_prunes_old_rows_when_thread_exceeds_size_limit() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let six_mebibytes = "a".repeat(6 * 1024 * 1024); + runtime + .insert_logs(&[ + LogEntry { + ts: 1, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes.clone()), + thread_id: Some("thread-1".to_string()), + process_uuid: Some("proc-1".to_string()), + file: 
Some("main.rs".to_string()), + line: Some(1), + module_path: Some("mod".to_string()), + }, + LogEntry { + ts: 2, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes.clone()), + thread_id: Some("thread-1".to_string()), + process_uuid: Some("proc-1".to_string()), + file: Some("main.rs".to_string()), + line: Some(2), + module_path: Some("mod".to_string()), + }, + ]) + .await + .expect("insert test logs"); + + let rows = runtime + .query_logs(&LogQuery { + thread_ids: vec!["thread-1".to_string()], + ..Default::default() + }) + .await + .expect("query thread logs"); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].ts, 2); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn insert_logs_prunes_single_thread_row_when_it_exceeds_size_limit() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let eleven_mebibytes = "d".repeat(11 * 1024 * 1024); + runtime + .insert_logs(&[LogEntry { + ts: 1, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(eleven_mebibytes), + thread_id: Some("thread-oversized".to_string()), + process_uuid: Some("proc-1".to_string()), + file: Some("main.rs".to_string()), + line: Some(1), + module_path: Some("mod".to_string()), + }]) + .await + .expect("insert test log"); + + let rows = runtime + .query_logs(&LogQuery { + thread_ids: vec!["thread-oversized".to_string()], + ..Default::default() + }) + .await + .expect("query thread logs"); + + assert!(rows.is_empty()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn insert_logs_prunes_threadless_rows_per_process_uuid_only() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let six_mebibytes = 
"b".repeat(6 * 1024 * 1024); + runtime + .insert_logs(&[ + LogEntry { + ts: 1, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes.clone()), + thread_id: None, + process_uuid: Some("proc-1".to_string()), + file: Some("main.rs".to_string()), + line: Some(1), + module_path: Some("mod".to_string()), + }, + LogEntry { + ts: 2, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes.clone()), + thread_id: None, + process_uuid: Some("proc-1".to_string()), + file: Some("main.rs".to_string()), + line: Some(2), + module_path: Some("mod".to_string()), + }, + LogEntry { + ts: 3, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes), + thread_id: Some("thread-1".to_string()), + process_uuid: Some("proc-1".to_string()), + file: Some("main.rs".to_string()), + line: Some(3), + module_path: Some("mod".to_string()), + }, + ]) + .await + .expect("insert test logs"); + + let rows = runtime + .query_logs(&LogQuery { + thread_ids: vec!["thread-1".to_string()], + include_threadless: true, + ..Default::default() + }) + .await + .expect("query thread and threadless logs"); + + let mut timestamps: Vec = rows.into_iter().map(|row| row.ts).collect(); + timestamps.sort_unstable(); + assert_eq!(timestamps, vec![2, 3]); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn insert_logs_prunes_single_threadless_process_row_when_it_exceeds_size_limit() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let eleven_mebibytes = "e".repeat(11 * 1024 * 1024); + runtime + .insert_logs(&[LogEntry { + ts: 1, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(eleven_mebibytes), + thread_id: None, + process_uuid: Some("proc-oversized".to_string()), + file: 
Some("main.rs".to_string()), + line: Some(1), + module_path: Some("mod".to_string()), + }]) + .await + .expect("insert test log"); + + let rows = runtime + .query_logs(&LogQuery { + include_threadless: true, + ..Default::default() + }) + .await + .expect("query threadless logs"); + + assert!(rows.is_empty()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn insert_logs_prunes_threadless_rows_with_null_process_uuid() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let six_mebibytes = "c".repeat(6 * 1024 * 1024); + runtime + .insert_logs(&[ + LogEntry { + ts: 1, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes.clone()), + thread_id: None, + process_uuid: None, + file: Some("main.rs".to_string()), + line: Some(1), + module_path: Some("mod".to_string()), + }, + LogEntry { + ts: 2, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(six_mebibytes), + thread_id: None, + process_uuid: None, + file: Some("main.rs".to_string()), + line: Some(2), + module_path: Some("mod".to_string()), + }, + LogEntry { + ts: 3, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some("small".to_string()), + thread_id: None, + process_uuid: Some("proc-1".to_string()), + file: Some("main.rs".to_string()), + line: Some(3), + module_path: Some("mod".to_string()), + }, + ]) + .await + .expect("insert test logs"); + + let rows = runtime + .query_logs(&LogQuery { + include_threadless: true, + ..Default::default() + }) + .await + .expect("query threadless logs"); + + let mut timestamps: Vec = rows.into_iter().map(|row| row.ts).collect(); + timestamps.sort_unstable(); + assert_eq!(timestamps, vec![2, 3]); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn 
insert_logs_prunes_single_threadless_null_process_row_when_it_exceeds_limit() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let eleven_mebibytes = "f".repeat(11 * 1024 * 1024); + runtime + .insert_logs(&[LogEntry { + ts: 1, + ts_nanos: 0, + level: "INFO".to_string(), + target: "cli".to_string(), + message: Some(eleven_mebibytes), + thread_id: None, + process_uuid: None, + file: Some("main.rs".to_string()), + line: Some(1), + module_path: Some("mod".to_string()), + }]) + .await + .expect("insert test log"); + + let rows = runtime + .query_logs(&LogQuery { + include_threadless: true, + ..Default::default() + }) + .await + .expect("query threadless logs"); + + assert!(rows.is_empty()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } +} diff --git a/codex-rs/state/src/runtime/memories.rs b/codex-rs/state/src/runtime/memories.rs index 859ad0d14d3..5ebb2be88ff 100644 --- a/codex-rs/state/src/runtime/memories.rs +++ b/codex-rs/state/src/runtime/memories.rs @@ -1,3 +1,5 @@ +use super::threads::push_thread_filters; +use super::threads::push_thread_order_and_limit; use super::*; use crate::model::Phase2InputSelection; use crate::model::Phase2JobClaimOutcome; @@ -13,6 +15,7 @@ use sqlx::Executor; use sqlx::QueryBuilder; use sqlx::Sqlite; use std::collections::HashSet; +use uuid::Uuid; const JOB_KIND_MEMORY_STAGE1: &str = "memory_stage1"; const JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL: &str = "memory_consolidate_global"; @@ -221,7 +224,7 @@ LEFT JOIN jobs /// /// Query behavior: /// - filters out rows where both `raw_memory` and `rollout_summary` are blank - /// - joins `threads` to include thread `cwd` and `rollout_path` + /// - joins `threads` to include thread `cwd`, `rollout_path`, and `git_branch` /// - orders by `source_updated_at DESC, thread_id DESC` /// - applies `LIMIT n` pub async fn list_stage1_outputs_for_global( @@ -241,8 +244,9 
@@ SELECT so.raw_memory, so.rollout_summary, so.rollout_slug, - so.generated_at - , COALESCE(t.cwd, '') AS cwd + so.generated_at, + COALESCE(t.cwd, '') AS cwd, + t.git_branch AS git_branch FROM stage1_outputs AS so LEFT JOIN threads AS t ON t.id = so.thread_id @@ -264,8 +268,13 @@ LIMIT ? /// last successful phase-2 selection. /// /// Query behavior: - /// - current selection is the latest `n` non-empty stage-1 outputs ordered - /// by `source_updated_at DESC, thread_id DESC` + /// - current selection keeps only non-empty stage-1 outputs whose + /// `last_usage` is within `max_unused_days`, or whose + /// `source_updated_at` is within that window when the memory has never + /// been used + /// - eligible rows are ordered by `usage_count DESC`, + /// `COALESCE(last_usage, source_updated_at) DESC`, `source_updated_at DESC`, + /// `thread_id DESC` /// - previously selected rows are identified by `selected_for_phase2 = 1` /// - `previous_selected` contains the current persisted rows that belonged /// to the last successful phase-2 baseline @@ -276,10 +285,12 @@ LIMIT ? pub async fn get_phase2_input_selection( &self, n: usize, + max_unused_days: i64, ) -> anyhow::Result { if n == 0 { return Ok(Phase2InputSelection::default()); } + let cutoff = (Utc::now() - Duration::days(max_unused_days.max(0))).timestamp(); let current_rows = sqlx::query( r#" @@ -291,17 +302,28 @@ SELECT so.rollout_summary, so.rollout_slug, so.generated_at, + COALESCE(t.cwd, '') AS cwd, + t.git_branch AS git_branch, so.selected_for_phase2, - so.selected_for_phase2_source_updated_at, - COALESCE(t.cwd, '') AS cwd + so.selected_for_phase2_source_updated_at FROM stage1_outputs AS so LEFT JOIN threads AS t ON t.id = so.thread_id -WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0 -ORDER BY so.source_updated_at DESC, so.thread_id DESC +WHERE (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) + AND ( + (so.last_usage IS NOT NULL AND so.last_usage >= ?) 
+ OR (so.last_usage IS NULL AND so.source_updated_at >= ?) + ) +ORDER BY + COALESCE(so.usage_count, 0) DESC, + COALESCE(so.last_usage, so.source_updated_at) DESC, + so.source_updated_at DESC, + so.thread_id DESC LIMIT ? "#, ) + .bind(cutoff) + .bind(cutoff) .bind(n as i64) .fetch_all(self.pool.as_ref()) .await?; @@ -332,9 +354,10 @@ SELECT so.source_updated_at, so.raw_memory, so.rollout_summary, - so.rollout_slug - , so.generated_at - , COALESCE(t.cwd, '') AS cwd + so.rollout_slug, + so.generated_at, + COALESCE(t.cwd, '') AS cwd, + t.git_branch AS git_branch FROM stage1_outputs AS so LEFT JOIN threads AS t ON t.id = so.thread_id @@ -1130,3 +1153,2515 @@ ON CONFLICT(kind, job_key) DO UPDATE SET Ok(()) } + +#[cfg(test)] +mod tests { + use super::StateRuntime; + use super::test_support::test_thread_metadata; + use super::test_support::unique_temp_dir; + use crate::model::Phase2JobClaimOutcome; + use crate::model::Stage1JobClaimOutcome; + use crate::model::Stage1StartupClaimParams; + use chrono::Duration; + use chrono::Utc; + use codex_protocol::ThreadId; + use pretty_assertions::assert_eq; + use sqlx::Row; + use std::sync::Arc; + use uuid::Uuid; + + #[tokio::test] + async fn stage1_claim_skips_when_up_to_date() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.join("a")); + runtime + .upsert_thread(&metadata) + .await + .expect("upsert thread"); + + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner_a, 100, 3600, 64) + .await + .expect("claim stage1 job"); + let ownership_token = match claim { + 
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + 100, + "raw", + "sum", + None, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should finalize for current token" + ); + + let up_to_date = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) + .await + .expect("claim stage1 up-to-date"); + assert_eq!(up_to_date, Stage1JobClaimOutcome::SkippedUpToDate); + + let needs_rerun = runtime + .try_claim_stage1_job(thread_id, owner_b, 101, 3600, 64) + .await + .expect("claim stage1 newer source"); + assert!( + matches!(needs_rerun, Stage1JobClaimOutcome::Claimed { .. }), + "newer source_updated_at should be claimable" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn stage1_running_stale_can_be_stolen_but_fresh_running_is_skipped() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let claim_a = runtime + .try_claim_stage1_job(thread_id, owner_a, 100, 3600, 64) + .await + .expect("claim a"); + assert!(matches!(claim_a, Stage1JobClaimOutcome::Claimed { .. 
})); + + let claim_b_fresh = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) + .await + .expect("claim b fresh"); + assert_eq!(claim_b_fresh, Stage1JobClaimOutcome::SkippedRunning); + + sqlx::query("UPDATE jobs SET lease_until = 0 WHERE kind = 'memory_stage1' AND job_key = ?") + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("force stale lease"); + + let claim_b_stale = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) + .await + .expect("claim b stale"); + assert!(matches!( + claim_b_stale, + Stage1JobClaimOutcome::Claimed { .. } + )); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn stage1_concurrent_claim_for_same_thread_is_conflict_safe() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let thread_id_a = thread_id; + let thread_id_b = thread_id; + let runtime_a = Arc::clone(&runtime); + let runtime_b = Arc::clone(&runtime); + let claim_with_retry = |runtime: Arc, + thread_id: ThreadId, + owner: ThreadId| async move { + for attempt in 0..5 { + match runtime + .try_claim_stage1_job(thread_id, owner, 100, 3_600, 64) + .await + { + Ok(outcome) => return outcome, + Err(err) if err.to_string().contains("database is locked") && attempt < 4 => { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + Err(err) => panic!("claim stage1 should not fail: {err}"), + } + } + panic!("claim stage1 should have 
returned within retry budget") + }; + + let (claim_a, claim_b) = tokio::join!( + claim_with_retry(runtime_a, thread_id_a, owner_a), + claim_with_retry(runtime_b, thread_id_b, owner_b), + ); + + let claim_outcomes = vec![claim_a, claim_b]; + let claimed_count = claim_outcomes + .iter() + .filter(|outcome| matches!(outcome, Stage1JobClaimOutcome::Claimed { .. })) + .count(); + assert_eq!(claimed_count, 1); + assert!( + claim_outcomes.iter().all(|outcome| { + matches!( + outcome, + Stage1JobClaimOutcome::Claimed { .. } | Stage1JobClaimOutcome::SkippedRunning + ) + }), + "unexpected claim outcomes: {claim_outcomes:?}" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn stage1_concurrent_claims_respect_running_cap() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_a, + codex_home.join("workspace-a"), + )) + .await + .expect("upsert thread a"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_b, + codex_home.join("workspace-b"), + )) + .await + .expect("upsert thread b"); + + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let runtime_a = Arc::clone(&runtime); + let runtime_b = Arc::clone(&runtime); + + let (claim_a, claim_b) = tokio::join!( + async move { + runtime_a + .try_claim_stage1_job(thread_a, owner_a, 100, 3_600, 1) + .await + .expect("claim stage1 thread a") + }, + async move { + runtime_b + .try_claim_stage1_job(thread_b, owner_b, 101, 3_600, 1) + .await + .expect("claim stage1 thread b") + }, + ); 
+ + let claim_outcomes = vec![claim_a, claim_b]; + let claimed_count = claim_outcomes + .iter() + .filter(|outcome| matches!(outcome, Stage1JobClaimOutcome::Claimed { .. })) + .count(); + assert_eq!(claimed_count, 1); + assert!( + claim_outcomes + .iter() + .any(|outcome| { matches!(outcome, Stage1JobClaimOutcome::SkippedRunning) }), + "one concurrent claim should be throttled by running cap: {claim_outcomes:?}" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn claim_stage1_jobs_filters_by_age_idle_and_current_thread() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let now = Utc::now(); + let fresh_at = now - Duration::hours(1); + let just_under_idle_at = now - Duration::hours(12) + Duration::minutes(1); + let eligible_idle_at = now - Duration::hours(12) - Duration::minutes(1); + let old_at = now - Duration::days(31); + + let current_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); + let fresh_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("fresh thread id"); + let just_under_idle_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("just under idle thread id"); + let eligible_idle_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("eligible idle thread id"); + let old_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("old thread id"); + + let mut current = + test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); + current.created_at = now; + current.updated_at = now; + runtime + .upsert_thread(&current) + .await + .expect("upsert current"); + + let mut fresh = + test_thread_metadata(&codex_home, fresh_thread_id, codex_home.join("fresh")); + fresh.created_at = fresh_at; + fresh.updated_at = fresh_at; + 
runtime.upsert_thread(&fresh).await.expect("upsert fresh"); + + let mut just_under_idle = test_thread_metadata( + &codex_home, + just_under_idle_thread_id, + codex_home.join("just-under-idle"), + ); + just_under_idle.created_at = just_under_idle_at; + just_under_idle.updated_at = just_under_idle_at; + runtime + .upsert_thread(&just_under_idle) + .await + .expect("upsert just-under-idle"); + + let mut eligible_idle = test_thread_metadata( + &codex_home, + eligible_idle_thread_id, + codex_home.join("eligible-idle"), + ); + eligible_idle.created_at = eligible_idle_at; + eligible_idle.updated_at = eligible_idle_at; + runtime + .upsert_thread(&eligible_idle) + .await + .expect("upsert eligible-idle"); + + let mut old = test_thread_metadata(&codex_home, old_thread_id, codex_home.join("old")); + old.created_at = old_at; + old.updated_at = old_at; + runtime.upsert_thread(&old).await.expect("upsert old"); + + let allowed_sources = vec!["cli".to_string()]; + let claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 1, + max_claimed: 5, + max_age_days: 30, + min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3600, + }, + ) + .await + .expect("claim stage1 jobs"); + + assert_eq!(claims.len(), 1); + assert_eq!(claims[0].thread.id, eligible_idle_thread_id); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn claim_stage1_jobs_prefilters_threads_with_up_to_date_memory() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let now = Utc::now(); + let eligible_newer_at = now - Duration::hours(13); + let eligible_older_at = now - Duration::hours(14); + + let current_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); + let up_to_date_thread_id = + 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("up-to-date thread id"); + let stale_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("stale thread id"); + let worker_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("worker id"); + + let mut current = + test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); + current.created_at = now; + current.updated_at = now; + runtime + .upsert_thread(&current) + .await + .expect("upsert current thread"); + + let mut up_to_date = test_thread_metadata( + &codex_home, + up_to_date_thread_id, + codex_home.join("up-to-date"), + ); + up_to_date.created_at = eligible_newer_at; + up_to_date.updated_at = eligible_newer_at; + runtime + .upsert_thread(&up_to_date) + .await + .expect("upsert up-to-date thread"); + + let up_to_date_claim = runtime + .try_claim_stage1_job( + up_to_date_thread_id, + worker_id, + up_to_date.updated_at.timestamp(), + 3600, + 64, + ) + .await + .expect("claim up-to-date thread for seed"); + let up_to_date_token = match up_to_date_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected seed claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + up_to_date_thread_id, + up_to_date_token.as_str(), + up_to_date.updated_at.timestamp(), + "raw", + "summary", + None, + ) + .await + .expect("mark up-to-date thread succeeded"), + "seed stage1 success should complete for up-to-date thread" + ); + + let mut stale = + test_thread_metadata(&codex_home, stale_thread_id, codex_home.join("stale")); + stale.created_at = eligible_older_at; + stale.updated_at = eligible_older_at; + runtime + .upsert_thread(&stale) + .await + .expect("upsert stale thread"); + + let allowed_sources = vec!["cli".to_string()]; + let claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 1, + max_claimed: 1, + max_age_days: 30, + 
min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3600, + }, + ) + .await + .expect("claim stage1 startup jobs"); + assert_eq!(claims.len(), 1); + assert_eq!(claims[0].thread.id, stale_thread_id); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn claim_stage1_jobs_enforces_global_running_cap() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let current_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + current_thread_id, + codex_home.join("current"), + )) + .await + .expect("upsert current"); + + let now = Utc::now(); + let started_at = now.timestamp(); + let lease_until = started_at + 3600; + let eligible_at = now - Duration::hours(13); + let existing_running = 10usize; + let total_candidates = 80usize; + + for idx in 0..total_candidates { + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let mut metadata = test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(format!("thread-{idx}")), + ); + metadata.created_at = eligible_at - Duration::seconds(idx as i64); + metadata.updated_at = eligible_at - Duration::seconds(idx as i64); + runtime + .upsert_thread(&metadata) + .await + .expect("upsert thread"); + + if idx < existing_running { + sqlx::query( + r#" +INSERT INTO jobs ( + kind, + job_key, + status, + worker_id, + ownership_token, + started_at, + finished_at, + lease_until, + retry_at, + retry_remaining, + last_error, + input_watermark, + last_success_watermark +) VALUES (?, ?, 'running', ?, ?, ?, NULL, ?, NULL, ?, NULL, ?, NULL) + "#, + ) + .bind("memory_stage1") + .bind(thread_id.to_string()) + .bind(current_thread_id.to_string()) + .bind(Uuid::new_v4().to_string()) + .bind(started_at) + 
.bind(lease_until) + .bind(3) + .bind(metadata.updated_at.timestamp()) + .execute(runtime.pool.as_ref()) + .await + .expect("seed running stage1 job"); + } + } + + let allowed_sources = vec!["cli".to_string()]; + let claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 200, + max_claimed: 64, + max_age_days: 30, + min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3600, + }, + ) + .await + .expect("claim stage1 jobs"); + assert_eq!(claims.len(), 54); + + let running_count = sqlx::query( + r#" +SELECT COUNT(*) AS count +FROM jobs +WHERE kind = 'memory_stage1' + AND status = 'running' + AND lease_until IS NOT NULL + AND lease_until > ? + "#, + ) + .bind(Utc::now().timestamp()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count running stage1 jobs") + .try_get::<i64>("count") + .expect("running count value"); + assert_eq!(running_count, 64); + + let more_claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 200, + max_claimed: 64, + max_age_days: 30, + min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3600, + }, + ) + .await + .expect("claim stage1 jobs with cap reached"); + assert_eq!(more_claims.len(), 0); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn claim_stage1_jobs_processes_two_full_batches_across_startup_passes() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let current_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); + let mut current = + test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); + current.created_at = Utc::now(); + current.updated_at = Utc::now(); + runtime + .upsert_thread(&current) + .await + 
.expect("upsert current"); + + let eligible_at = Utc::now() - Duration::hours(13); + for idx in 0..200 { + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let mut metadata = test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(format!("thread-{idx}")), + ); + metadata.created_at = eligible_at - Duration::seconds(idx as i64); + metadata.updated_at = eligible_at - Duration::seconds(idx as i64); + runtime + .upsert_thread(&metadata) + .await + .expect("upsert eligible thread"); + } + + let allowed_sources = vec!["cli".to_string()]; + let first_claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 5_000, + max_claimed: 64, + max_age_days: 30, + min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3_600, + }, + ) + .await + .expect("first stage1 startup claim"); + assert_eq!(first_claims.len(), 64); + + for claim in first_claims { + assert!( + runtime + .mark_stage1_job_succeeded( + claim.thread.id, + claim.ownership_token.as_str(), + claim.thread.updated_at.timestamp(), + "raw", + "summary", + None, + ) + .await + .expect("mark first-batch stage1 success"), + "first batch stage1 completion should succeed" + ); + } + + let second_claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 5_000, + max_claimed: 64, + max_age_days: 30, + min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3_600, + }, + ) + .await + .expect("second stage1 startup claim"); + assert_eq!(second_claims.len(), 64); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn stage1_output_cascades_on_thread_delete() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + 100, + "raw", + "sum", + None, + ) + .await + .expect("mark stage1 succeeded"), + "mark stage1 succeeded should write stage1_outputs" + ); + + let count_before = + sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count before delete") + .try_get::<i64>("count") + .expect("count value"); + assert_eq!(count_before, 1); + + sqlx::query("DELETE FROM threads WHERE id = ?") + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("delete thread"); + + let count_after = + sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count after delete") + .try_get::<i64>("count") + .expect("count value"); + assert_eq!(count_after, 0); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_stage1_job_succeeded_no_output_skips_phase2_when_output_was_already_absent() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded_no_output(thread_id, ownership_token.as_str()) + .await + .expect("mark stage1 succeeded without output"), + "stage1 no-output success should complete the job" + ); + + let output_row_count = + sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 output count") + .try_get::<i64>("count") + .expect("stage1 output count"); + assert_eq!( + output_row_count, 0, + "stage1 no-output success should not persist empty stage1 outputs" + ); + + let up_to_date = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600, 64) + .await + .expect("claim stage1 up-to-date"); + assert_eq!(up_to_date, Stage1JobClaimOutcome::SkippedUpToDate); + + let global_job_row_count = sqlx::query("SELECT COUNT(*) AS count FROM jobs WHERE kind = ?") + .bind("memory_consolidate_global") + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load phase2 job row count") + .try_get::<i64>("count") + .expect("phase2 job row count"); + assert_eq!( + global_job_row_count, 0, + "no-output without an existing stage1 output should not enqueue phase2" + ); + + let claim_phase2 = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + 
assert_eq!( + claim_phase2, + Phase2JobClaimOutcome::SkippedNotDirty, + "phase2 should remain clean when no-output deleted nothing" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_stage1_job_succeeded_no_output_enqueues_phase2_when_deleting_output() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let first_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let first_token = match first_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected initial stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_id, first_token.as_str(), 100, "raw", "sum", None) + .await + .expect("mark initial stage1 succeeded"), + "initial stage1 success should create stage1 output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2 after initial output"); + let (phase2_token, phase2_input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim after initial output: {other:?}"), + }; + assert_eq!(phase2_input_watermark, 100); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + phase2_input_watermark, + &[], + ) + 
.await + .expect("mark initial phase2 succeeded"), + "initial phase2 success should clear global dirty state" + ); + + let no_output_claim = runtime + .try_claim_stage1_job(thread_id, owner_b, 101, 3600, 64) + .await + .expect("claim stage1 for no-output delete"); + let no_output_token = match no_output_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected no-output stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded_no_output(thread_id, no_output_token.as_str()) + .await + .expect("mark stage1 no-output after existing output"), + "no-output should succeed when deleting an existing stage1 output" + ); + + let output_row_count = + sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 output count after delete") + .try_get::<i64>("count") + .expect("stage1 output count"); + assert_eq!(output_row_count, 0); + + let claim_phase2 = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2 after no-output deletion"); + let (phase2_token, phase2_input_watermark) = match claim_phase2 { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim after no-output deletion: {other:?}"), + }; + assert_eq!(phase2_input_watermark, 101); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + phase2_input_watermark, + &[], + ) + .await + .expect("mark phase2 succeeded after no-output delete") + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn stage1_retry_exhaustion_does_not_block_newer_watermark() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + 
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + for attempt in 0..3 { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3_600, 64) + .await + .expect("claim stage1 for retry exhaustion"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!( + "attempt {} should claim stage1 before retries are exhausted: {other:?}", + attempt + 1 + ), + }; + assert!( + runtime + .mark_stage1_job_failed(thread_id, ownership_token.as_str(), "boom", 0) + .await + .expect("mark stage1 failed"), + "attempt {} should decrement retry budget", + attempt + 1 + ); + } + + let exhausted_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3_600, 64) + .await + .expect("claim stage1 after retry exhaustion"); + assert_eq!( + exhausted_claim, + Stage1JobClaimOutcome::SkippedRetryExhausted + ); + + let newer_source_claim = runtime + .try_claim_stage1_job(thread_id, owner, 101, 3_600, 64) + .await + .expect("claim stage1 with newer source watermark"); + assert!( + matches!(newer_source_claim, Stage1JobClaimOutcome::Claimed { .. }), + "newer source watermark should reset retry budget and be claimable" + ); + + let job_row = sqlx::query( + "SELECT retry_remaining, input_watermark FROM jobs WHERE kind = ? 
AND job_key = ?", + ) + .bind("memory_stage1") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 job row after newer-source claim"); + assert_eq!( + job_row + .try_get::<i64>("retry_remaining") + .expect("retry_remaining"), + 3 + ); + assert_eq!( + job_row + .try_get::<i64>("input_watermark") + .expect("input_watermark"), + 101 + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase2_global_consolidation_reruns_when_watermark_advances() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .enqueue_global_consolidation(100) + .await + .expect("enqueue global consolidation"); + + let claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (ownership_token, input_watermark) = match claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark, &[],) + .await + .expect("mark phase2 succeeded"), + "phase2 success should finalize for current token" + ); + + let claim_up_to_date = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2 up-to-date"); + assert_eq!(claim_up_to_date, Phase2JobClaimOutcome::SkippedNotDirty); + + runtime + .enqueue_global_consolidation(101) + .await + .expect("enqueue global consolidation again"); + + let claim_rerun = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2 rerun"); + assert!( + matches!(claim_rerun, Phase2JobClaimOutcome::Claimed { .. 
}), + "advanced watermark should be claimable" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn list_stage1_outputs_for_global_returns_latest_outputs() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id_a, + codex_home.join("workspace-a"), + )) + .await + .expect("upsert thread a"); + let mut metadata_b = + test_thread_metadata(&codex_home, thread_id_b, codex_home.join("workspace-b")); + metadata_b.git_branch = Some("feature/stage1-b".to_string()); + runtime + .upsert_thread(&metadata_b) + .await + .expect("upsert thread b"); + + let claim = runtime + .try_claim_stage1_job(thread_id_a, owner, 100, 3600, 64) + .await + .expect("claim stage1 a"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id_a, + ownership_token.as_str(), + 100, + "raw memory a", + "summary a", + None, + ) + .await + .expect("mark stage1 succeeded a"), + "stage1 success should persist output a" + ); + + let claim = runtime + .try_claim_stage1_job(thread_id_b, owner, 101, 3600, 64) + .await + .expect("claim stage1 b"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id_b, + ownership_token.as_str(), + 101, + 
"raw memory b", + "summary b", + Some("rollout-b"), + ) + .await + .expect("mark stage1 succeeded b"), + "stage1 success should persist output b" + ); + + let outputs = runtime + .list_stage1_outputs_for_global(10) + .await + .expect("list stage1 outputs for global"); + assert_eq!(outputs.len(), 2); + assert_eq!(outputs[0].thread_id, thread_id_b); + assert_eq!(outputs[0].rollout_summary, "summary b"); + assert_eq!(outputs[0].rollout_slug.as_deref(), Some("rollout-b")); + assert_eq!(outputs[0].cwd, codex_home.join("workspace-b")); + assert_eq!(outputs[0].git_branch.as_deref(), Some("feature/stage1-b")); + assert_eq!(outputs[1].thread_id, thread_id_a); + assert_eq!(outputs[1].rollout_summary, "summary a"); + assert_eq!(outputs[1].rollout_slug, None); + assert_eq!(outputs[1].cwd, codex_home.join("workspace-a")); + assert_eq!(outputs[1].git_branch, None); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn list_stage1_outputs_for_global_skips_empty_payloads() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_non_empty = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_empty = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id_non_empty, + codex_home.join("workspace-non-empty"), + )) + .await + .expect("upsert non-empty thread"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id_empty, + codex_home.join("workspace-empty"), + )) + .await + .expect("upsert empty thread"); + + sqlx::query( + r#" +INSERT INTO stage1_outputs (thread_id, source_updated_at, raw_memory, rollout_summary, generated_at) +VALUES (?, ?, ?, ?, ?) 
+ "#, + ) + .bind(thread_id_non_empty.to_string()) + .bind(100_i64) + .bind("raw memory") + .bind("summary") + .bind(100_i64) + .execute(runtime.pool.as_ref()) + .await + .expect("insert non-empty stage1 output"); + sqlx::query( + r#" +INSERT INTO stage1_outputs (thread_id, source_updated_at, raw_memory, rollout_summary, generated_at) +VALUES (?, ?, ?, ?, ?) + "#, + ) + .bind(thread_id_empty.to_string()) + .bind(101_i64) + .bind("") + .bind("") + .bind(101_i64) + .execute(runtime.pool.as_ref()) + .await + .expect("insert empty stage1 output"); + + let outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list stage1 outputs for global"); + assert_eq!(outputs.len(), 1); + assert_eq!(outputs[0].thread_id, thread_id_non_empty); + assert_eq!(outputs[0].rollout_summary, "summary"); + assert_eq!(outputs[0].cwd, codex_home.join("workspace-non-empty")); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + for (thread_id, workspace) in [ + (thread_id_a, "workspace-a"), + (thread_id_b, "workspace-b"), + (thread_id_c, "workspace-c"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, updated_at, slug) in [ + (thread_id_a, 100, Some("rollout-a")), + (thread_id_b, 101, Some("rollout-b")), + 
(thread_id_c, 102, Some("rollout-c")), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + &format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + slug, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + } + + let claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (ownership_token, input_watermark) = match claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + assert_eq!(input_watermark, 102); + let selected_outputs = runtime + .list_stage1_outputs_for_global(10) + .await + .expect("list stage1 outputs for global") + .into_iter() + .filter(|output| output.thread_id == thread_id_c || output.thread_id == thread_id_a) + .collect::<Vec<_>>(); + assert!( + runtime + .mark_global_phase2_job_succeeded( + ownership_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success with selection"), + "phase2 success should persist selected rows" + ); + + let selection = runtime + .get_phase2_input_selection(2, 36_500) + .await + .expect("load phase2 input selection"); + + assert_eq!(selection.selected.len(), 2); + assert_eq!(selection.previous_selected.len(), 2); + assert_eq!(selection.selected[0].thread_id, thread_id_c); + assert_eq!( + selection.selected[0].rollout_path, + codex_home.join(format!("rollout-{thread_id_c}.jsonl")) + ); + assert_eq!(selection.selected[1].thread_id, thread_id_b); + assert_eq!(selection.retained_thread_ids, 
vec![thread_id_c]); + + assert_eq!(selection.removed.len(), 1); + assert_eq!(selection.removed[0].thread_id, thread_id_a); + assert_eq!( + selection.removed[0].rollout_slug.as_deref(), + Some("rollout-a") + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let first_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let first_token = match first_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + first_token.as_str(), + 100, + "raw-100", + "summary-100", + Some("rollout-100"), + ) + .await + .expect("mark initial stage1 success"), + "initial stage1 success should persist output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list selected outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), 
+ input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + let refreshed_claim = runtime + .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) + .await + .expect("claim refreshed stage1"); + let refreshed_token = match refreshed_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + refreshed_token.as_str(), + 101, + "raw-101", + "summary-101", + Some("rollout-101"), + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + + let selection = runtime + .get_phase2_input_selection(1, 36_500) + .await + .expect("load phase2 input selection"); + assert_eq!(selection.selected.len(), 1); + assert_eq!(selection.previous_selected.len(), 1); + assert_eq!(selection.selected[0].thread_id, thread_id); + assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101); + assert!(selection.retained_thread_ids.is_empty()); + assert!(selection.removed.is_empty()); + + let (selected_for_phase2, selected_for_phase2_source_updated_at) = + sqlx::query_as::<_, (i64, Option)>( + "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", + ) + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load selected_for_phase2"); + assert_eq!(selected_for_phase2, 1); + assert_eq!(selected_for_phase2_source_updated_at, Some(100)); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_reports_regenerated_previous_selection_as_removed() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_a = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread a"); + let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread b"); + let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread c"); + let thread_id_d = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread d"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + for (thread_id, workspace) in [ + (thread_id_a, "workspace-a"), + (thread_id_b, "workspace-b"), + (thread_id_c, "workspace-c"), + (thread_id_d, "workspace-d"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, updated_at, slug) in [ + (thread_id_a, 100, Some("rollout-a-100")), + (thread_id_b, 101, Some("rollout-b-101")), + (thread_id_c, 99, Some("rollout-c-99")), + (thread_id_d, 98, Some("rollout-d-98")), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim initial stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + &format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + slug, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + } + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(2) + 
.await + .expect("list selected outputs"); + assert_eq!( + selected_outputs + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_id_b, thread_id_a] + ); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + for (thread_id, updated_at, slug) in [ + (thread_id_a, 102, Some("rollout-a-102")), + (thread_id_c, 103, Some("rollout-c-103")), + (thread_id_d, 104, Some("rollout-d-104")), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim refreshed stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + &format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + slug, + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + } + + let selection = runtime + .get_phase2_input_selection(2, 36_500) + .await + .expect("load phase2 input selection"); + assert_eq!( + selection + .selected + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_id_d, thread_id_c] + ); + assert_eq!( + selection + .previous_selected + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_id_a, thread_id_b] + ); + assert!(selection.retained_thread_ids.is_empty()); + assert_eq!( + selection + .removed + .iter() + .map(|output| (output.thread_id, output.source_updated_at.timestamp())) + .collect::>(), + vec![(thread_id_a, 102), (thread_id_b, 101)] + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn 
mark_global_phase2_job_succeeded_updates_selected_snapshot_timestamp() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let initial_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let initial_token = match initial_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + initial_token.as_str(), + 100, + "raw-100", + "summary-100", + Some("rollout-100"), + ) + .await + .expect("mark initial stage1 success"), + "initial stage1 success should persist output" + ); + + let first_phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim first phase2"); + let (first_phase2_token, first_input_watermark) = match first_phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected first phase2 claim outcome: {other:?}"), + }; + let first_selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list first selected outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + first_phase2_token.as_str(), + first_input_watermark, + &first_selected_outputs, + ) + .await + .expect("mark first phase2 success"), + "first phase2 success should persist selected rows" + ); + + let refreshed_claim = runtime + .try_claim_stage1_job(thread_id, 
owner, 101, 3600, 64) + .await + .expect("claim refreshed stage1"); + let refreshed_token = match refreshed_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected refreshed stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + refreshed_token.as_str(), + 101, + "raw-101", + "summary-101", + Some("rollout-101"), + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + + let second_phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim second phase2"); + let (second_phase2_token, second_input_watermark) = match second_phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected second phase2 claim outcome: {other:?}"), + }; + let second_selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list second selected outputs"); + assert_eq!( + second_selected_outputs[0].source_updated_at.timestamp(), + 101 + ); + assert!( + runtime + .mark_global_phase2_job_succeeded( + second_phase2_token.as_str(), + second_input_watermark, + &second_selected_outputs, + ) + .await + .expect("mark second phase2 success"), + "second phase2 success should persist selected rows" + ); + + let selection = runtime + .get_phase2_input_selection(1, 36_500) + .await + .expect("load phase2 input selection after refresh"); + assert_eq!(selection.retained_thread_ids, vec![thread_id]); + + let (selected_for_phase2, selected_for_phase2_source_updated_at) = + sqlx::query_as::<_, (i64, Option)>( + "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", + ) + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load selected snapshot after phase2"); + assert_eq!(selected_for_phase2, 1); + 
assert_eq!(selected_for_phase2_source_updated_at, Some(101)); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_global_phase2_job_succeeded_only_marks_exact_selected_snapshots() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let initial_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let initial_token = match initial_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + initial_token.as_str(), + 100, + "raw-100", + "summary-100", + Some("rollout-100"), + ) + .await + .expect("mark initial stage1 success"), + "initial stage1 success should persist output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list selected outputs"); + assert_eq!(selected_outputs[0].source_updated_at.timestamp(), 100); + + let refreshed_claim = runtime + .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) + .await + .expect("claim refreshed stage1"); + let 
refreshed_token = match refreshed_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + refreshed_token.as_str(), + 101, + "raw-101", + "summary-101", + Some("rollout-101"), + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should still complete" + ); + + let (selected_for_phase2, selected_for_phase2_source_updated_at) = + sqlx::query_as::<_, (i64, Option)>( + "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", + ) + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load selected_for_phase2"); + assert_eq!(selected_for_phase2, 0); + assert_eq!(selected_for_phase2_source_updated_at, None); + + let selection = runtime + .get_phase2_input_selection(1, 36_500) + .await + .expect("load phase2 input selection"); + assert_eq!(selection.selected.len(), 1); + assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101); + assert!(selection.retained_thread_ids.is_empty()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn record_stage1_output_usage_updates_usage_metadata() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); + let missing = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("missing id"); + let owner = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_a, + codex_home.join("workspace-a"), + )) + .await + .expect("upsert thread a"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_b, + codex_home.join("workspace-b"), + )) + .await + .expect("upsert thread b"); + + let claim_a = runtime + .try_claim_stage1_job(thread_a, owner, 100, 3600, 64) + .await + .expect("claim stage1 a"); + let token_a = match claim_a { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for a: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_a, token_a.as_str(), 100, "raw a", "sum a", None) + .await + .expect("mark stage1 succeeded a") + ); + + let claim_b = runtime + .try_claim_stage1_job(thread_b, owner, 101, 3600, 64) + .await + .expect("claim stage1 b"); + let token_b = match claim_b { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for b: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_b, token_b.as_str(), 101, "raw b", "sum b", None) + .await + .expect("mark stage1 succeeded b") + ); + + let updated_rows = runtime + .record_stage1_output_usage(&[thread_a, thread_a, thread_b, missing]) + .await + .expect("record stage1 output usage"); + assert_eq!(updated_rows, 3); + + let row_a = + sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_a.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 usage row a"); + let row_b = + sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_b.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 usage row b"); + + assert_eq!( + row_a + .try_get::("usage_count") + .expect("usage_count a"), + 
2 + ); + assert_eq!( + row_b + .try_get::("usage_count") + .expect("usage_count b"), + 1 + ); + + let last_usage_a = row_a.try_get::("last_usage").expect("last_usage a"); + let last_usage_b = row_b.try_get::("last_usage").expect("last_usage b"); + assert_eq!(last_usage_a, last_usage_b); + assert!(last_usage_a > 0); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_prioritizes_usage_count_then_recent_usage() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let now = Utc::now(); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); + let thread_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id c"); + + for (thread_id, workspace) in [ + (thread_a, "workspace-a"), + (thread_b, "workspace-b"), + (thread_c, "workspace-c"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, generated_at, summary) in [ + (thread_a, now - Duration::days(3), "summary-a"), + (thread_b, now - Duration::days(2), "summary-b"), + (thread_c, now - Duration::days(1), "summary-c"), + ] { + let source_updated_at = generated_at.timestamp(); + let claim = runtime + .try_claim_stage1_job(thread_id, owner, source_updated_at, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + 
source_updated_at, + &format!("raw-{summary}"), + summary, + None, + ) + .await + .expect("mark stage1 success"), + "stage1 success should persist output" + ); + } + + for (thread_id, usage_count, last_usage) in [ + (thread_a, 5_i64, now - Duration::days(10)), + (thread_b, 5_i64, now - Duration::days(1)), + (thread_c, 1_i64, now - Duration::hours(1)), + ] { + sqlx::query( + "UPDATE stage1_outputs SET usage_count = ?, last_usage = ? WHERE thread_id = ?", + ) + .bind(usage_count) + .bind(last_usage.timestamp()) + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("update usage metadata"); + } + + let selection = runtime + .get_phase2_input_selection(3, 30) + .await + .expect("load phase2 input selection"); + + assert_eq!( + selection + .selected + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_b, thread_a, thread_c] + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_excludes_stale_used_memories_but_keeps_fresh_never_used() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let now = Utc::now(); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); + let thread_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id c"); + + for (thread_id, workspace) in [ + (thread_a, "workspace-a"), + (thread_b, "workspace-b"), + (thread_c, "workspace-c"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, generated_at, summary) in [ + (thread_a, now - Duration::days(40), 
"summary-a"), + (thread_b, now - Duration::days(2), "summary-b"), + (thread_c, now - Duration::days(50), "summary-c"), + ] { + let source_updated_at = generated_at.timestamp(); + let claim = runtime + .try_claim_stage1_job(thread_id, owner, source_updated_at, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + source_updated_at, + &format!("raw-{summary}"), + summary, + None, + ) + .await + .expect("mark stage1 success"), + "stage1 success should persist output" + ); + } + + for (thread_id, usage_count, last_usage) in [ + (thread_a, Some(9_i64), Some(now - Duration::days(31))), + (thread_b, None, None), + (thread_c, Some(1_i64), Some(now - Duration::days(1))), + ] { + sqlx::query( + "UPDATE stage1_outputs SET usage_count = ?, last_usage = ? 
WHERE thread_id = ?", + ) + .bind(usage_count) + .bind(last_usage.map(|value| value.timestamp())) + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("update usage metadata"); + } + + let selection = runtime + .get_phase2_input_selection(3, 30) + .await + .expect("load phase2 input selection"); + + assert_eq!( + selection + .selected + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_c, thread_b] + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_prefers_recent_thread_updates_over_recent_generation() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let older_thread = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("older thread id"); + let newer_thread = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("newer thread id"); + + for (thread_id, workspace) in [ + (older_thread, "workspace-older"), + (newer_thread, "workspace-newer"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, source_updated_at, summary) in [ + (older_thread, 100_i64, "summary-older"), + (newer_thread, 200_i64, "summary-newer"), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, source_updated_at, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + source_updated_at, + &format!("raw-{summary}"), + summary, + None, + ) + .await + 
.expect("mark stage1 success"), + "stage1 success should persist output" + ); + } + + sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?") + .bind(300_i64) + .bind(older_thread.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("update older generated_at"); + sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?") + .bind(150_i64) + .bind(newer_thread.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("update newer generated_at"); + + let selection = runtime + .get_phase2_input_selection(1, 36_500) + .await + .expect("load phase2 input selection"); + + assert_eq!(selection.selected.len(), 1); + assert_eq!(selection.selected[0].thread_id, newer_thread); + assert_eq!(selection.selected[0].source_updated_at.timestamp(), 200); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_stage1_job_succeeded_enqueues_global_consolidation() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_a, + codex_home.join("workspace-a"), + )) + .await + .expect("upsert thread a"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_b, + codex_home.join("workspace-b"), + )) + .await + .expect("upsert thread b"); + + let claim_a = runtime + .try_claim_stage1_job(thread_a, owner, 100, 3600, 64) + .await + .expect("claim stage1 a"); + let token_a = match claim_a { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for thread 
a: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_a, + token_a.as_str(), + 100, + "raw-a", + "summary-a", + None, + ) + .await + .expect("mark stage1 succeeded a"), + "stage1 success should persist output for thread a" + ); + + let claim_b = runtime + .try_claim_stage1_job(thread_b, owner, 101, 3600, 64) + .await + .expect("claim stage1 b"); + let token_b = match claim_b { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for thread b: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_b, + token_b.as_str(), + 101, + "raw-b", + "summary-b", + None, + ) + .await + .expect("mark stage1 succeeded b"), + "stage1 success should persist output for thread b" + ); + + let claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim global consolidation"); + let input_watermark = match claim { + Phase2JobClaimOutcome::Claimed { + input_watermark, .. + } => input_watermark, + other => panic!("unexpected global consolidation claim outcome: {other:?}"), + }; + assert_eq!(input_watermark, 101); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase2_global_lock_allows_only_one_fresh_runner() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + runtime + .enqueue_global_consolidation(200) + .await + .expect("enqueue global consolidation"); + + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner a"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner b"); + + let running_claim = runtime + .try_claim_global_phase2_job(owner_a, 3600) + .await + .expect("claim global lock"); + assert!( + matches!(running_claim, Phase2JobClaimOutcome::Claimed { .. 
}), + "first owner should claim global lock" + ); + + let second_claim = runtime + .try_claim_global_phase2_job(owner_b, 3600) + .await + .expect("claim global lock from second owner"); + assert_eq!(second_claim, Phase2JobClaimOutcome::SkippedRunning); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase2_global_lock_stale_lease_allows_takeover() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + runtime + .enqueue_global_consolidation(300) + .await + .expect("enqueue global consolidation"); + + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner a"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner b"); + + let initial_claim = runtime + .try_claim_global_phase2_job(owner_a, 3600) + .await + .expect("claim initial global lock"); + let token_a = match initial_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, .. + } => ownership_token, + other => panic!("unexpected initial claim outcome: {other:?}"), + }; + + sqlx::query("UPDATE jobs SET lease_until = ? WHERE kind = ? 
AND job_key = ?") + .bind(Utc::now().timestamp() - 1) + .bind("memory_consolidate_global") + .bind("global") + .execute(runtime.pool.as_ref()) + .await + .expect("expire global consolidation lease"); + + let takeover_claim = runtime + .try_claim_global_phase2_job(owner_b, 3600) + .await + .expect("claim stale global lock"); + let (token_b, input_watermark) = match takeover_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected takeover claim outcome: {other:?}"), + }; + assert_ne!(token_a, token_b); + assert_eq!(input_watermark, 300); + + assert_eq!( + runtime + .mark_global_phase2_job_succeeded(token_a.as_str(), 300, &[]) + .await + .expect("mark stale owner success result"), + false, + "stale owner should lose finalization ownership after takeover" + ); + assert!( + runtime + .mark_global_phase2_job_succeeded(token_b.as_str(), 300, &[]) + .await + .expect("mark takeover owner success"), + "takeover owner should finalize consolidation" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase2_backfilled_inputs_below_last_success_still_become_dirty() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + runtime + .enqueue_global_consolidation(500) + .await + .expect("enqueue initial consolidation"); + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner a"); + let claim_a = runtime + .try_claim_global_phase2_job(owner_a, 3_600) + .await + .expect("claim initial consolidation"); + let token_a = match claim_a { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => { + assert_eq!(input_watermark, 500); + ownership_token + } + other => panic!("unexpected initial phase2 claim outcome: {other:?}"), + }; + assert!( + runtime + 
.mark_global_phase2_job_succeeded(token_a.as_str(), 500, &[]) + .await + .expect("mark initial phase2 success"), + "initial phase2 success should finalize" + ); + + runtime + .enqueue_global_consolidation(400) + .await + .expect("enqueue backfilled consolidation"); + + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner b"); + let claim_b = runtime + .try_claim_global_phase2_job(owner_b, 3_600) + .await + .expect("claim backfilled consolidation"); + match claim_b { + Phase2JobClaimOutcome::Claimed { + input_watermark, .. + } => { + assert!( + input_watermark > 500, + "backfilled enqueue should advance dirty watermark beyond last success" + ); + } + other => panic!("unexpected backfilled phase2 claim outcome: {other:?}"), + } + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase2_failure_fallback_updates_unowned_running_job() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + runtime + .enqueue_global_consolidation(400) + .await + .expect("enqueue global consolidation"); + + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner"); + let claim = runtime + .try_claim_global_phase2_job(owner, 3_600) + .await + .expect("claim global consolidation"); + let ownership_token = match claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, .. + } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + sqlx::query("UPDATE jobs SET ownership_token = NULL WHERE kind = ? 
AND job_key = ?") + .bind("memory_consolidate_global") + .bind("global") + .execute(runtime.pool.as_ref()) + .await + .expect("clear ownership token"); + + assert_eq!( + runtime + .mark_global_phase2_job_failed(ownership_token.as_str(), "lost", 3_600) + .await + .expect("mark phase2 failed with strict ownership"), + false, + "strict failure update should not match unowned running job" + ); + assert!( + runtime + .mark_global_phase2_job_failed_if_unowned(ownership_token.as_str(), "lost", 3_600) + .await + .expect("fallback failure update should match unowned running job"), + "fallback failure update should transition the unowned running job" + ); + + let claim = runtime + .try_claim_global_phase2_job(ThreadId::new(), 3_600) + .await + .expect("claim after fallback failure"); + assert_eq!(claim, Phase2JobClaimOutcome::SkippedNotDirty); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } +} diff --git a/codex-rs/state/src/runtime/test_support.rs b/codex-rs/state/src/runtime/test_support.rs new file mode 100644 index 00000000000..d749fe2bfba --- /dev/null +++ b/codex-rs/state/src/runtime/test_support.rs @@ -0,0 +1,64 @@ +#[cfg(test)] +use chrono::DateTime; +#[cfg(test)] +use chrono::Utc; +#[cfg(test)] +use codex_protocol::ThreadId; +#[cfg(test)] +use codex_protocol::protocol::AskForApproval; +#[cfg(test)] +use codex_protocol::protocol::SandboxPolicy; +#[cfg(test)] +use std::path::Path; +#[cfg(test)] +use std::path::PathBuf; +#[cfg(test)] +use std::time::SystemTime; +#[cfg(test)] +use std::time::UNIX_EPOCH; +#[cfg(test)] +use uuid::Uuid; + +#[cfg(test)] +use crate::ThreadMetadata; + +#[cfg(test)] +pub(super) fn unique_temp_dir() -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_nanos()); + std::env::temp_dir().join(format!( + "codex-state-runtime-test-{nanos}-{}", + Uuid::new_v4() + )) +} + +#[cfg(test)] +pub(super) fn test_thread_metadata( + codex_home: &Path, + thread_id: ThreadId, + cwd: PathBuf, 
+) -> ThreadMetadata { + let now = DateTime::::from_timestamp(1_700_000_000, 0).expect("timestamp"); + ThreadMetadata { + id: thread_id, + rollout_path: codex_home.join(format!("rollout-{thread_id}.jsonl")), + created_at: now, + updated_at: now, + source: "cli".to_string(), + agent_nickname: None, + agent_role: None, + model_provider: "test-provider".to_string(), + cwd, + cli_version: "0.0.0".to_string(), + title: String::new(), + sandbox_policy: crate::extract::enum_to_string(&SandboxPolicy::new_read_only_policy()), + approval_mode: crate::extract::enum_to_string(&AskForApproval::OnRequest), + tokens_used: 0, + first_user_message: Some("hello".to_string()), + archived_at: None, + git_sha: None, + git_branch: None, + git_origin_url: None, + } +} diff --git a/codex-rs/state/src/runtime/threads.rs b/codex-rs/state/src/runtime/threads.rs new file mode 100644 index 00000000000..0eb64d6be56 --- /dev/null +++ b/codex-rs/state/src/runtime/threads.rs @@ -0,0 +1,496 @@ +use super::*; + +impl StateRuntime { + pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT + id, + rollout_path, + created_at, + updated_at, + source, + agent_nickname, + agent_role, + model_provider, + cwd, + cli_version, + title, + sandbox_policy, + approval_mode, + tokens_used, + first_user_message, + archived_at, + git_sha, + git_branch, + git_origin_url +FROM threads +WHERE id = ? + "#, + ) + .bind(id.to_string()) + .fetch_optional(self.pool.as_ref()) + .await?; + row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from)) + .transpose() + } + + /// Get dynamic tools for a thread, if present. + pub async fn get_dynamic_tools( + &self, + thread_id: ThreadId, + ) -> anyhow::Result>> { + let rows = sqlx::query( + r#" +SELECT name, description, input_schema +FROM thread_dynamic_tools +WHERE thread_id = ? 
+ORDER BY position ASC + "#, + ) + .bind(thread_id.to_string()) + .fetch_all(self.pool.as_ref()) + .await?; + if rows.is_empty() { + return Ok(None); + } + let mut tools = Vec::with_capacity(rows.len()); + for row in rows { + let input_schema: String = row.try_get("input_schema")?; + let input_schema = serde_json::from_str::(input_schema.as_str())?; + tools.push(DynamicToolSpec { + name: row.try_get("name")?, + description: row.try_get("description")?, + input_schema, + }); + } + Ok(Some(tools)) + } + + /// Find a rollout path by thread id using the underlying database. + pub async fn find_rollout_path_by_id( + &self, + id: ThreadId, + archived_only: Option, + ) -> anyhow::Result> { + let mut builder = + QueryBuilder::::new("SELECT rollout_path FROM threads WHERE id = "); + builder.push_bind(id.to_string()); + match archived_only { + Some(true) => { + builder.push(" AND archived = 1"); + } + Some(false) => { + builder.push(" AND archived = 0"); + } + None => {} + } + let row = builder.build().fetch_optional(self.pool.as_ref()).await?; + Ok(row + .and_then(|r| r.try_get::("rollout_path").ok()) + .map(PathBuf::from)) + } + + /// List threads using the underlying database. 
+ #[allow(clippy::too_many_arguments)] + pub async fn list_threads( + &self, + page_size: usize, + anchor: Option<&crate::Anchor>, + sort_key: crate::SortKey, + allowed_sources: &[String], + model_providers: Option<&[String]>, + archived_only: bool, + search_term: Option<&str>, + ) -> anyhow::Result { + let limit = page_size.saturating_add(1); + + let mut builder = QueryBuilder::::new( + r#" +SELECT + id, + rollout_path, + created_at, + updated_at, + source, + agent_nickname, + agent_role, + model_provider, + cwd, + cli_version, + title, + sandbox_policy, + approval_mode, + tokens_used, + first_user_message, + archived_at, + git_sha, + git_branch, + git_origin_url +FROM threads + "#, + ); + push_thread_filters( + &mut builder, + archived_only, + allowed_sources, + model_providers, + anchor, + sort_key, + search_term, + ); + push_thread_order_and_limit(&mut builder, sort_key, limit); + + let rows = builder.build().fetch_all(self.pool.as_ref()).await?; + let mut items = rows + .into_iter() + .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from)) + .collect::, _>>()?; + let num_scanned_rows = items.len(); + let next_anchor = if items.len() > page_size { + items.pop(); + items + .last() + .and_then(|item| anchor_from_item(item, sort_key)) + } else { + None + }; + Ok(ThreadsPage { + items, + next_anchor, + num_scanned_rows, + }) + } + + /// List thread ids using the underlying database (no rollout scanning). 
+ pub async fn list_thread_ids( + &self, + limit: usize, + anchor: Option<&crate::Anchor>, + sort_key: crate::SortKey, + allowed_sources: &[String], + model_providers: Option<&[String]>, + archived_only: bool, + ) -> anyhow::Result> { + let mut builder = QueryBuilder::::new("SELECT id FROM threads"); + push_thread_filters( + &mut builder, + archived_only, + allowed_sources, + model_providers, + anchor, + sort_key, + None, + ); + push_thread_order_and_limit(&mut builder, sort_key, limit); + + let rows = builder.build().fetch_all(self.pool.as_ref()).await?; + rows.into_iter() + .map(|row| { + let id: String = row.try_get("id")?; + Ok(ThreadId::try_from(id)?) + }) + .collect() + } + + /// Insert or replace thread metadata directly. + pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> { + sqlx::query( + r#" +INSERT INTO threads ( + id, + rollout_path, + created_at, + updated_at, + source, + agent_nickname, + agent_role, + model_provider, + cwd, + cli_version, + title, + sandbox_policy, + approval_mode, + tokens_used, + first_user_message, + archived, + archived_at, + git_sha, + git_branch, + git_origin_url +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ON CONFLICT(id) DO UPDATE SET + rollout_path = excluded.rollout_path, + created_at = excluded.created_at, + updated_at = excluded.updated_at, + source = excluded.source, + agent_nickname = excluded.agent_nickname, + agent_role = excluded.agent_role, + model_provider = excluded.model_provider, + cwd = excluded.cwd, + cli_version = excluded.cli_version, + title = excluded.title, + sandbox_policy = excluded.sandbox_policy, + approval_mode = excluded.approval_mode, + tokens_used = excluded.tokens_used, + first_user_message = excluded.first_user_message, + archived = excluded.archived, + archived_at = excluded.archived_at, + git_sha = excluded.git_sha, + git_branch = excluded.git_branch, + git_origin_url = excluded.git_origin_url + "#, + ) + .bind(metadata.id.to_string()) + .bind(metadata.rollout_path.display().to_string()) + .bind(datetime_to_epoch_seconds(metadata.created_at)) + .bind(datetime_to_epoch_seconds(metadata.updated_at)) + .bind(metadata.source.as_str()) + .bind(metadata.agent_nickname.as_deref()) + .bind(metadata.agent_role.as_deref()) + .bind(metadata.model_provider.as_str()) + .bind(metadata.cwd.display().to_string()) + .bind(metadata.cli_version.as_str()) + .bind(metadata.title.as_str()) + .bind(metadata.sandbox_policy.as_str()) + .bind(metadata.approval_mode.as_str()) + .bind(metadata.tokens_used) + .bind(metadata.first_user_message.as_deref().unwrap_or_default()) + .bind(metadata.archived_at.is_some()) + .bind(metadata.archived_at.map(datetime_to_epoch_seconds)) + .bind(metadata.git_sha.as_deref()) + .bind(metadata.git_branch.as_deref()) + .bind(metadata.git_origin_url.as_deref()) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + /// Persist dynamic tools for a thread if none have been stored yet. + /// + /// Dynamic tools are defined at thread start and should not change afterward. + /// This only writes the first time we see tools for a given thread. 
+ pub async fn persist_dynamic_tools( + &self, + thread_id: ThreadId, + tools: Option<&[DynamicToolSpec]>, + ) -> anyhow::Result<()> { + let Some(tools) = tools else { + return Ok(()); + }; + if tools.is_empty() { + return Ok(()); + } + let thread_id = thread_id.to_string(); + let mut tx = self.pool.begin().await?; + for (idx, tool) in tools.iter().enumerate() { + let position = i64::try_from(idx).unwrap_or(i64::MAX); + let input_schema = serde_json::to_string(&tool.input_schema)?; + sqlx::query( + r#" +INSERT INTO thread_dynamic_tools ( + thread_id, + position, + name, + description, + input_schema +) VALUES (?, ?, ?, ?, ?) +ON CONFLICT(thread_id, position) DO NOTHING + "#, + ) + .bind(thread_id.as_str()) + .bind(position) + .bind(tool.name.as_str()) + .bind(tool.description.as_str()) + .bind(input_schema) + .execute(&mut *tx) + .await?; + } + tx.commit().await?; + Ok(()) + } + + /// Apply rollout items incrementally using the underlying database. + pub async fn apply_rollout_items( + &self, + builder: &ThreadMetadataBuilder, + items: &[RolloutItem], + otel: Option<&OtelManager>, + ) -> anyhow::Result<()> { + if items.is_empty() { + return Ok(()); + } + let mut metadata = self + .get_thread(builder.id) + .await? + .unwrap_or_else(|| builder.build(&self.default_provider)); + metadata.rollout_path = builder.rollout_path.clone(); + for item in items { + apply_rollout_item(&mut metadata, item, &self.default_provider); + } + if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await { + metadata.updated_at = updated_at; + } + // Keep the thread upsert before dynamic tools to satisfy the foreign key constraint: + // thread_dynamic_tools.thread_id -> threads.id. 
+ if let Err(err) = self.upsert_thread(&metadata).await { + if let Some(otel) = otel { + otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]); + } + return Err(err); + } + let dynamic_tools = extract_dynamic_tools(items); + if let Some(dynamic_tools) = dynamic_tools + && let Err(err) = self + .persist_dynamic_tools(builder.id, dynamic_tools.as_deref()) + .await + { + if let Some(otel) = otel { + otel.counter(DB_ERROR_METRIC, 1, &[("stage", "persist_dynamic_tools")]); + } + return Err(err); + } + Ok(()) + } + + /// Mark a thread as archived using the underlying database. + pub async fn mark_archived( + &self, + thread_id: ThreadId, + rollout_path: &Path, + archived_at: DateTime, + ) -> anyhow::Result<()> { + let Some(mut metadata) = self.get_thread(thread_id).await? else { + return Ok(()); + }; + metadata.archived_at = Some(archived_at); + metadata.rollout_path = rollout_path.to_path_buf(); + if let Some(updated_at) = file_modified_time_utc(rollout_path).await { + metadata.updated_at = updated_at; + } + if metadata.id != thread_id { + warn!( + "thread id mismatch during archive: expected {thread_id}, got {}", + metadata.id + ); + } + self.upsert_thread(&metadata).await + } + + /// Mark a thread as unarchived using the underlying database. + pub async fn mark_unarchived( + &self, + thread_id: ThreadId, + rollout_path: &Path, + ) -> anyhow::Result<()> { + let Some(mut metadata) = self.get_thread(thread_id).await? else { + return Ok(()); + }; + metadata.archived_at = None; + metadata.rollout_path = rollout_path.to_path_buf(); + if let Some(updated_at) = file_modified_time_utc(rollout_path).await { + metadata.updated_at = updated_at; + } + if metadata.id != thread_id { + warn!( + "thread id mismatch during unarchive: expected {thread_id}, got {}", + metadata.id + ); + } + self.upsert_thread(&metadata).await + } + + /// Delete a thread metadata row by id. 
+ pub async fn delete_thread(&self, thread_id: ThreadId) -> anyhow::Result { + let result = sqlx::query("DELETE FROM threads WHERE id = ?") + .bind(thread_id.to_string()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected()) + } +} + +pub(super) fn extract_dynamic_tools(items: &[RolloutItem]) -> Option>> { + items.iter().find_map(|item| match item { + RolloutItem::SessionMeta(meta_line) => Some(meta_line.meta.dynamic_tools.clone()), + RolloutItem::ResponseItem(_) + | RolloutItem::Compacted(_) + | RolloutItem::TurnContext(_) + | RolloutItem::EventMsg(_) => None, + }) +} + +pub(super) fn push_thread_filters<'a>( + builder: &mut QueryBuilder<'a, Sqlite>, + archived_only: bool, + allowed_sources: &'a [String], + model_providers: Option<&'a [String]>, + anchor: Option<&crate::Anchor>, + sort_key: SortKey, + search_term: Option<&'a str>, +) { + builder.push(" WHERE 1 = 1"); + if archived_only { + builder.push(" AND archived = 1"); + } else { + builder.push(" AND archived = 0"); + } + builder.push(" AND first_user_message <> ''"); + if !allowed_sources.is_empty() { + builder.push(" AND source IN ("); + let mut separated = builder.separated(", "); + for source in allowed_sources { + separated.push_bind(source); + } + separated.push_unseparated(")"); + } + if let Some(model_providers) = model_providers + && !model_providers.is_empty() + { + builder.push(" AND model_provider IN ("); + let mut separated = builder.separated(", "); + for provider in model_providers { + separated.push_bind(provider); + } + separated.push_unseparated(")"); + } + if let Some(search_term) = search_term { + builder.push(" AND instr(title, "); + builder.push_bind(search_term); + builder.push(") > 0"); + } + if let Some(anchor) = anchor { + let anchor_ts = datetime_to_epoch_seconds(anchor.ts); + let column = match sort_key { + SortKey::CreatedAt => "created_at", + SortKey::UpdatedAt => "updated_at", + }; + builder.push(" AND ("); + builder.push(column); + builder.push(" < "); + 
builder.push_bind(anchor_ts); + builder.push(" OR ("); + builder.push(column); + builder.push(" = "); + builder.push_bind(anchor_ts); + builder.push(" AND id < "); + builder.push_bind(anchor.id.to_string()); + builder.push("))"); + } +} + +pub(super) fn push_thread_order_and_limit( + builder: &mut QueryBuilder<'_, Sqlite>, + sort_key: SortKey, + limit: usize, +) { + let order_column = match sort_key { + SortKey::CreatedAt => "created_at", + SortKey::UpdatedAt => "updated_at", + }; + builder.push(" ORDER BY "); + builder.push(order_column); + builder.push(" DESC, id DESC"); + builder.push(" LIMIT "); + builder.push_bind(limit as i64); +} diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index 6ed8670e075..e32a76da0fa 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -1,6 +1,7 @@ use crate::app_backtrack::BacktrackState; use crate::app_event::AppEvent; use crate::app_event::ExitMode; +use crate::app_event::RealtimeAudioDeviceKind; #[cfg(target_os = "windows")] use crate::app_event::WindowsSandboxEnableMode; use crate::app_event_sender::AppEventSender; @@ -1414,12 +1415,16 @@ impl App { }; ChatWidget::new(init, thread_manager.clone()) } - SessionSelection::Resume(path) => { + SessionSelection::Resume(target_session) => { let resumed = thread_manager - .resume_thread_from_rollout(config.clone(), path.clone(), auth_manager.clone()) + .resume_thread_from_rollout( + config.clone(), + target_session.path.clone(), + auth_manager.clone(), + ) .await .wrap_err_with(|| { - let path_display = path.display(); + let path_display = target_session.path.display(); format!("Failed to resume session from {path_display}") })?; let init = crate::chatwidget::ChatWidgetInit { @@ -1444,13 +1449,18 @@ impl App { }; ChatWidget::new_from_existing(init, resumed.thread, resumed.session_configured) } - SessionSelection::Fork(path) => { + SessionSelection::Fork(target_session) => { otel_manager.counter("codex.thread.fork", 1, &[("source", "cli_subcommand")]); let 
forked = thread_manager - .fork_thread(usize::MAX, config.clone(), path.clone(), false) + .fork_thread( + usize::MAX, + config.clone(), + target_session.path.clone(), + false, + ) .await .wrap_err_with(|| { - let path_display = path.display(); + let path_display = target_session.path.display(); format!("Failed to fork session from {path_display}") })?; let init = crate::chatwidget::ChatWidgetInit { @@ -1713,12 +1723,14 @@ impl App { } AppEvent::OpenResumePicker => { match crate::resume_picker::run_resume_picker(tui, &self.config, false).await? { - SessionSelection::Resume(path) => { + SessionSelection::Resume(target_session) => { let current_cwd = self.config.cwd.clone(); let resume_cwd = match crate::resolve_cwd_for_resume_or_fork( tui, + &self.config, ¤t_cwd, - &path, + target_session.thread_id, + &target_session.path, CwdPromptAction::Resume, true, ) @@ -1754,7 +1766,7 @@ impl App { .server .resume_thread_from_rollout( resume_config.clone(), - path.clone(), + target_session.path.clone(), self.auth_manager.clone(), ) .await @@ -1788,7 +1800,7 @@ impl App { } } Err(err) => { - let path_display = path.display(); + let path_display = target_session.path.display(); self.chat_widget.add_error_message(format!( "Failed to resume session from {path_display}: {err}" )); @@ -1925,19 +1937,9 @@ impl App { AppEvent::CodexEvent(event) => { self.enqueue_primary_event(event).await?; } - AppEvent::Exit(mode) => match mode { - ExitMode::ShutdownFirst => { - // Mark the thread we are explicitly shutting down for exit so - // its shutdown completion does not trigger agent failover. 
- self.pending_shutdown_exit_thread_id = - self.active_thread_id.or(self.chat_widget.thread_id()); - self.chat_widget.submit_op(Op::Shutdown); - } - ExitMode::Immediate => { - self.pending_shutdown_exit_thread_id = None; - return Ok(AppRunControl::Exit(ExitReason::UserRequested)); - } - }, + AppEvent::Exit(mode) => { + return Ok(self.handle_exit_mode(mode)); + } AppEvent::FatalExitRequest(message) => { return Ok(AppRunControl::Exit(ExitReason::Fatal(message))); } @@ -2019,6 +2021,9 @@ impl App { AppEvent::UpdatePersonality(personality) => { self.on_update_personality(personality); } + AppEvent::OpenRealtimeAudioDeviceSelection { kind } => { + self.chat_widget.open_realtime_audio_device_selection(kind); + } AppEvent::OpenReasoningPopup { model } => { self.chat_widget.open_reasoning_popup(model); } @@ -2444,6 +2449,56 @@ impl App { } } } + AppEvent::PersistRealtimeAudioDeviceSelection { kind, name } => { + let builder = match kind { + RealtimeAudioDeviceKind::Microphone => { + ConfigEditsBuilder::new(&self.config.codex_home) + .set_realtime_microphone(name.as_deref()) + } + RealtimeAudioDeviceKind::Speaker => { + ConfigEditsBuilder::new(&self.config.codex_home) + .set_realtime_speaker(name.as_deref()) + } + }; + + match builder.apply().await { + Ok(()) => { + match kind { + RealtimeAudioDeviceKind::Microphone => { + self.config.realtime_audio.microphone = name.clone(); + } + RealtimeAudioDeviceKind::Speaker => { + self.config.realtime_audio.speaker = name.clone(); + } + } + self.chat_widget + .set_realtime_audio_device(kind, name.clone()); + + if self.chat_widget.realtime_conversation_is_live() { + self.chat_widget.open_realtime_audio_restart_prompt(kind); + } else { + let selection = name.unwrap_or_else(|| "System default".to_string()); + self.chat_widget.add_info_message( + format!("Realtime {} set to {selection}", kind.noun()), + None, + ); + } + } + Err(err) => { + tracing::error!( + error = %err, + "failed to persist realtime audio selection" + ); + 
self.chat_widget.add_error_message(format!( + "Failed to save realtime {}: {err}", + kind.noun() + )); + } + } + } + AppEvent::RestartRealtimeAudioDevice { kind } => { + self.chat_widget.restart_realtime_audio_device(kind); + } AppEvent::UpdateAskForApprovalPolicy(policy) => { self.runtime_approval_policy_override = Some(policy); if let Err(err) = self.config.permissions.approval_policy.set(policy) { @@ -2914,6 +2969,27 @@ impl App { Ok(AppRunControl::Continue) } + fn handle_exit_mode(&mut self, mode: ExitMode) -> AppRunControl { + match mode { + ExitMode::ShutdownFirst => { + // Mark the thread we are explicitly shutting down for exit so + // its shutdown completion does not trigger agent failover. + self.pending_shutdown_exit_thread_id = + self.active_thread_id.or(self.chat_widget.thread_id()); + if self.chat_widget.submit_op(Op::Shutdown) { + AppRunControl::Continue + } else { + self.pending_shutdown_exit_thread_id = None; + AppRunControl::Exit(ExitReason::UserRequested) + } + } + ExitMode::Immediate => { + self.pending_shutdown_exit_thread_id = None; + AppRunControl::Exit(ExitReason::UserRequested) + } + } + } + fn handle_codex_event_now(&mut self, event: Event) { let needs_refresh = matches!( event.msg, @@ -3389,15 +3465,21 @@ mod tests { true ); assert_eq!( - App::should_wait_for_initial_session(&SessionSelection::Resume(PathBuf::from( - "/tmp/restore" - ))), + App::should_wait_for_initial_session(&SessionSelection::Resume( + crate::resume_picker::SessionTarget { + path: PathBuf::from("/tmp/restore"), + thread_id: ThreadId::new(), + } + )), false ); assert_eq!( - App::should_wait_for_initial_session(&SessionSelection::Fork(PathBuf::from( - "/tmp/fork" - ))), + App::should_wait_for_initial_session(&SessionSelection::Fork( + crate::resume_picker::SessionTarget { + path: PathBuf::from("/tmp/fork"), + thread_id: ThreadId::new(), + } + )), false ); } @@ -3433,14 +3515,20 @@ mod tests { #[test] fn 
startup_waiting_gate_not_applied_for_resume_or_fork_session_selection() { let wait_for_resume = App::should_wait_for_initial_session(&SessionSelection::Resume( - PathBuf::from("/tmp/restore"), + crate::resume_picker::SessionTarget { + path: PathBuf::from("/tmp/restore"), + thread_id: ThreadId::new(), + }, )); assert_eq!( App::should_handle_active_thread_events(wait_for_resume, true), true ); let wait_for_fork = App::should_wait_for_initial_session(&SessionSelection::Fork( - PathBuf::from("/tmp/fork"), + crate::resume_picker::SessionTarget { + path: PathBuf::from("/tmp/fork"), + thread_id: ThreadId::new(), + }, )); assert_eq!( App::should_handle_active_thread_events(wait_for_fork, true), @@ -4703,6 +4791,34 @@ mod tests { } } + #[tokio::test] + async fn shutdown_first_exit_returns_immediate_exit_when_shutdown_submit_fails() { + let mut app = make_test_app().await; + let thread_id = ThreadId::new(); + app.active_thread_id = Some(thread_id); + + let control = app.handle_exit_mode(ExitMode::ShutdownFirst); + + assert_eq!(app.pending_shutdown_exit_thread_id, None); + assert!(matches!( + control, + AppRunControl::Exit(ExitReason::UserRequested) + )); + } + + #[tokio::test] + async fn shutdown_first_exit_waits_for_shutdown_when_submit_succeeds() { + let (mut app, _app_event_rx, mut op_rx) = make_test_app_with_channels().await; + let thread_id = ThreadId::new(); + app.active_thread_id = Some(thread_id); + + let control = app.handle_exit_mode(ExitMode::ShutdownFirst); + + assert_eq!(app.pending_shutdown_exit_thread_id, Some(thread_id)); + assert!(matches!(control, AppRunControl::Continue)); + assert_eq!(op_rx.try_recv(), Ok(Op::Shutdown)); + } + #[tokio::test] async fn clear_only_ui_reset_preserves_chat_session_state() { let mut app = make_test_app().await; diff --git a/codex-rs/tui/src/app_event.rs b/codex-rs/tui/src/app_event.rs index 0767696e282..104895ad9e6 100644 --- a/codex-rs/tui/src/app_event.rs +++ b/codex-rs/tui/src/app_event.rs @@ -29,6 +29,28 @@ use 
codex_protocol::openai_models::ReasoningEffort; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::SandboxPolicy; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum RealtimeAudioDeviceKind { + Microphone, + Speaker, +} + +impl RealtimeAudioDeviceKind { + pub(crate) fn title(self) -> &'static str { + match self { + Self::Microphone => "Microphone", + Self::Speaker => "Speaker", + } + } + + pub(crate) fn noun(self) -> &'static str { + match self { + Self::Microphone => "microphone", + Self::Speaker => "speaker", + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(not(target_os = "windows"), allow(dead_code))] pub(crate) enum WindowsSandboxEnableMode { @@ -166,6 +188,26 @@ pub(crate) enum AppEvent { personality: Personality, }, + /// Open the device picker for a realtime microphone or speaker. + OpenRealtimeAudioDeviceSelection { + kind: RealtimeAudioDeviceKind, + }, + + /// Persist the selected realtime microphone or speaker to top-level config. + #[cfg_attr( + any(target_os = "linux", not(feature = "voice-input")), + allow(dead_code) + )] + PersistRealtimeAudioDeviceSelection { + kind: RealtimeAudioDeviceKind, + name: Option, + }, + + /// Restart the selected realtime microphone or speaker locally. + RestartRealtimeAudioDevice { + kind: RealtimeAudioDeviceKind, + }, + /// Open the reasoning selection popup after picking a model. 
OpenReasoningPopup { model: ModelPreset, diff --git a/codex-rs/tui/src/audio_device.rs b/codex-rs/tui/src/audio_device.rs new file mode 100644 index 00000000000..467693e7fe7 --- /dev/null +++ b/codex-rs/tui/src/audio_device.rs @@ -0,0 +1,129 @@ +use codex_core::config::Config; +use cpal::traits::DeviceTrait; +use cpal::traits::HostTrait; +use tracing::warn; + +use crate::app_event::RealtimeAudioDeviceKind; + +pub(crate) fn list_realtime_audio_device_names( + kind: RealtimeAudioDeviceKind, +) -> Result, String> { + let host = cpal::default_host(); + let mut device_names = Vec::new(); + for device in devices(&host, kind)? { + let Ok(name) = device.name() else { + continue; + }; + if !device_names.contains(&name) { + device_names.push(name); + } + } + Ok(device_names) +} + +pub(crate) fn select_configured_input_device_and_config( + config: &Config, +) -> Result<(cpal::Device, cpal::SupportedStreamConfig), String> { + select_device_and_config(RealtimeAudioDeviceKind::Microphone, config) +} + +pub(crate) fn select_configured_output_device_and_config( + config: &Config, +) -> Result<(cpal::Device, cpal::SupportedStreamConfig), String> { + select_device_and_config(RealtimeAudioDeviceKind::Speaker, config) +} + +fn select_device_and_config( + kind: RealtimeAudioDeviceKind, + config: &Config, +) -> Result<(cpal::Device, cpal::SupportedStreamConfig), String> { + let host = cpal::default_host(); + let configured_name = configured_name(kind, config); + let selected = configured_name + .and_then(|name| find_device_by_name(&host, kind, name)) + .or_else(|| { + let default_device = default_device(&host, kind); + if let Some(name) = configured_name && default_device.is_some() { + warn!( + "configured {} audio device `{name}` was unavailable; falling back to system default", + kind.noun() + ); + } + default_device + }) + .ok_or_else(|| missing_device_error(kind, configured_name))?; + + let stream_config = default_config(&selected, kind)?; + Ok((selected, stream_config)) +} + +fn 
configured_name(kind: RealtimeAudioDeviceKind, config: &Config) -> Option<&str> { + match kind { + RealtimeAudioDeviceKind::Microphone => config.realtime_audio.microphone.as_deref(), + RealtimeAudioDeviceKind::Speaker => config.realtime_audio.speaker.as_deref(), + } +} + +fn find_device_by_name( + host: &cpal::Host, + kind: RealtimeAudioDeviceKind, + name: &str, +) -> Option { + let devices = devices(host, kind).ok()?; + devices + .into_iter() + .find(|device| device.name().ok().as_deref() == Some(name)) +} + +fn devices(host: &cpal::Host, kind: RealtimeAudioDeviceKind) -> Result, String> { + match kind { + RealtimeAudioDeviceKind::Microphone => host + .input_devices() + .map(|devices| devices.collect()) + .map_err(|err| format!("failed to enumerate input audio devices: {err}")), + RealtimeAudioDeviceKind::Speaker => host + .output_devices() + .map(|devices| devices.collect()) + .map_err(|err| format!("failed to enumerate output audio devices: {err}")), + } +} + +fn default_device(host: &cpal::Host, kind: RealtimeAudioDeviceKind) -> Option { + match kind { + RealtimeAudioDeviceKind::Microphone => host.default_input_device(), + RealtimeAudioDeviceKind::Speaker => host.default_output_device(), + } +} + +fn default_config( + device: &cpal::Device, + kind: RealtimeAudioDeviceKind, +) -> Result { + match kind { + RealtimeAudioDeviceKind::Microphone => device + .default_input_config() + .map_err(|err| format!("failed to get default input config: {err}")), + RealtimeAudioDeviceKind::Speaker => device + .default_output_config() + .map_err(|err| format!("failed to get default output config: {err}")), + } +} + +fn missing_device_error(kind: RealtimeAudioDeviceKind, configured_name: Option<&str>) -> String { + match (kind, configured_name) { + (RealtimeAudioDeviceKind::Microphone, Some(name)) => { + format!( + "configured microphone `{name}` was unavailable and no default input audio device was found" + ) + } + (RealtimeAudioDeviceKind::Speaker, Some(name)) => { + format!( + 
"configured speaker `{name}` was unavailable and no default output audio device was found" + ) + } + (RealtimeAudioDeviceKind::Microphone, None) => { + "no input audio device available".to_string() + } + (RealtimeAudioDeviceKind::Speaker, None) => "no output audio device available".to_string(), + } +} diff --git a/codex-rs/tui/src/bottom_pane/chat_composer.rs b/codex-rs/tui/src/bottom_pane/chat_composer.rs index 9943a22531d..fe40f4fa7bf 100644 --- a/codex-rs/tui/src/bottom_pane/chat_composer.rs +++ b/codex-rs/tui/src/bottom_pane/chat_composer.rs @@ -395,6 +395,7 @@ pub(crate) struct ChatComposer { connectors_enabled: bool, personality_command_enabled: bool, realtime_conversation_enabled: bool, + audio_device_selection_enabled: bool, windows_degraded_sandbox_active: bool, status_line_value: Option>, status_line_enabled: bool, @@ -500,6 +501,7 @@ impl ChatComposer { connectors_enabled: false, personality_command_enabled: false, realtime_conversation_enabled: false, + audio_device_selection_enabled: false, windows_degraded_sandbox_active: false, status_line_value: None, status_line_enabled: false, @@ -577,10 +579,13 @@ impl ChatComposer { self.realtime_conversation_enabled = enabled; } + pub fn set_audio_device_selection_enabled(&mut self, enabled: bool) { + self.audio_device_selection_enabled = enabled; + } + /// Compatibility shim for tests that still toggle the removed steer mode flag. 
#[cfg(test)] pub fn set_steer_enabled(&mut self, _enabled: bool) {} - pub fn set_voice_transcription_enabled(&mut self, enabled: bool) { self.voice_state.transcription_enabled = enabled; if !enabled { @@ -2264,6 +2269,7 @@ impl ChatComposer { self.connectors_enabled, self.personality_command_enabled, self.realtime_conversation_enabled, + self.audio_device_selection_enabled, self.windows_degraded_sandbox_active, ) .is_some(); @@ -2480,6 +2486,7 @@ impl ChatComposer { self.connectors_enabled, self.personality_command_enabled, self.realtime_conversation_enabled, + self.audio_device_selection_enabled, self.windows_degraded_sandbox_active, ) { @@ -2515,6 +2522,7 @@ impl ChatComposer { self.connectors_enabled, self.personality_command_enabled, self.realtime_conversation_enabled, + self.audio_device_selection_enabled, self.windows_degraded_sandbox_active, )?; @@ -3334,6 +3342,7 @@ impl ChatComposer { self.connectors_enabled, self.personality_command_enabled, self.realtime_conversation_enabled, + self.audio_device_selection_enabled, self.windows_degraded_sandbox_active, ) .is_some(); @@ -3396,6 +3405,7 @@ impl ChatComposer { self.connectors_enabled, self.personality_command_enabled, self.realtime_conversation_enabled, + self.audio_device_selection_enabled, self.windows_degraded_sandbox_active, ) { return true; @@ -3450,6 +3460,7 @@ impl ChatComposer { let connectors_enabled = self.connectors_enabled; let personality_command_enabled = self.personality_command_enabled; let realtime_conversation_enabled = self.realtime_conversation_enabled; + let audio_device_selection_enabled = self.audio_device_selection_enabled; let mut command_popup = CommandPopup::new( self.custom_prompts.clone(), CommandPopupFlags { @@ -3457,6 +3468,7 @@ impl ChatComposer { connectors_enabled, personality_command_enabled, realtime_conversation_enabled, + audio_device_selection_enabled, windows_degraded_sandbox_active: self.windows_degraded_sandbox_active, }, ); diff --git 
a/codex-rs/tui/src/bottom_pane/command_popup.rs b/codex-rs/tui/src/bottom_pane/command_popup.rs index 83b523cc234..62765d9c2c8 100644 --- a/codex-rs/tui/src/bottom_pane/command_popup.rs +++ b/codex-rs/tui/src/bottom_pane/command_popup.rs @@ -40,6 +40,7 @@ pub(crate) struct CommandPopupFlags { pub(crate) connectors_enabled: bool, pub(crate) personality_command_enabled: bool, pub(crate) realtime_conversation_enabled: bool, + pub(crate) audio_device_selection_enabled: bool, pub(crate) windows_degraded_sandbox_active: bool, } @@ -51,6 +52,7 @@ impl CommandPopup { flags.connectors_enabled, flags.personality_command_enabled, flags.realtime_conversation_enabled, + flags.audio_device_selection_enabled, flags.windows_degraded_sandbox_active, ) .into_iter() @@ -498,6 +500,7 @@ mod tests { connectors_enabled: false, personality_command_enabled: true, realtime_conversation_enabled: false, + audio_device_selection_enabled: false, windows_degraded_sandbox_active: false, }, ); @@ -518,6 +521,7 @@ mod tests { connectors_enabled: false, personality_command_enabled: true, realtime_conversation_enabled: false, + audio_device_selection_enabled: false, windows_degraded_sandbox_active: false, }, ); @@ -538,6 +542,7 @@ mod tests { connectors_enabled: false, personality_command_enabled: false, realtime_conversation_enabled: false, + audio_device_selection_enabled: false, windows_degraded_sandbox_active: false, }, ); @@ -566,6 +571,7 @@ mod tests { connectors_enabled: false, personality_command_enabled: true, realtime_conversation_enabled: false, + audio_device_selection_enabled: false, windows_degraded_sandbox_active: false, }, ); @@ -577,6 +583,36 @@ mod tests { } } + #[test] + fn settings_command_hidden_when_audio_device_selection_is_disabled() { + let mut popup = CommandPopup::new( + Vec::new(), + CommandPopupFlags { + collaboration_modes_enabled: false, + connectors_enabled: false, + personality_command_enabled: true, + realtime_conversation_enabled: true, + 
audio_device_selection_enabled: false, + windows_degraded_sandbox_active: false, + }, + ); + popup.on_composer_text_change("/aud".to_string()); + + let cmds: Vec<&str> = popup + .filtered_items() + .into_iter() + .filter_map(|item| match item { + CommandItem::Builtin(cmd) => Some(cmd.command()), + CommandItem::UserPrompt(_) => None, + }) + .collect(); + + assert!( + !cmds.contains(&"settings"), + "expected '/settings' to be hidden when audio device selection is disabled, got {cmds:?}" + ); + } + #[test] fn debug_commands_are_hidden_from_popup() { let popup = CommandPopup::new(Vec::new(), CommandPopupFlags::default()); diff --git a/codex-rs/tui/src/bottom_pane/mod.rs b/codex-rs/tui/src/bottom_pane/mod.rs index e774e9c4200..dac00698871 100644 --- a/codex-rs/tui/src/bottom_pane/mod.rs +++ b/codex-rs/tui/src/bottom_pane/mod.rs @@ -298,6 +298,11 @@ impl BottomPane { self.request_redraw(); } + pub fn set_audio_device_selection_enabled(&mut self, enabled: bool) { + self.composer.set_audio_device_selection_enabled(enabled); + self.request_redraw(); + } + pub fn set_voice_transcription_enabled(&mut self, enabled: bool) { self.composer.set_voice_transcription_enabled(enabled); self.request_redraw(); diff --git a/codex-rs/tui/src/bottom_pane/slash_commands.rs b/codex-rs/tui/src/bottom_pane/slash_commands.rs index 86c131e6db1..981c61c451b 100644 --- a/codex-rs/tui/src/bottom_pane/slash_commands.rs +++ b/codex-rs/tui/src/bottom_pane/slash_commands.rs @@ -14,6 +14,7 @@ pub(crate) fn builtins_for_input( connectors_enabled: bool, personality_command_enabled: bool, realtime_conversation_enabled: bool, + audio_device_selection_enabled: bool, allow_elevate_sandbox: bool, ) -> Vec<(&'static str, SlashCommand)> { built_in_slash_commands() @@ -26,6 +27,7 @@ pub(crate) fn builtins_for_input( .filter(|(_, cmd)| connectors_enabled || *cmd != SlashCommand::Apps) .filter(|(_, cmd)| personality_command_enabled || *cmd != SlashCommand::Personality) .filter(|(_, cmd)| 
realtime_conversation_enabled || *cmd != SlashCommand::Realtime) + .filter(|(_, cmd)| audio_device_selection_enabled || *cmd != SlashCommand::Settings) .collect() } @@ -36,6 +38,7 @@ pub(crate) fn find_builtin_command( connectors_enabled: bool, personality_command_enabled: bool, realtime_conversation_enabled: bool, + audio_device_selection_enabled: bool, allow_elevate_sandbox: bool, ) -> Option { builtins_for_input( @@ -43,6 +46,7 @@ pub(crate) fn find_builtin_command( connectors_enabled, personality_command_enabled, realtime_conversation_enabled, + audio_device_selection_enabled, allow_elevate_sandbox, ) .into_iter() @@ -57,6 +61,7 @@ pub(crate) fn has_builtin_prefix( connectors_enabled: bool, personality_command_enabled: bool, realtime_conversation_enabled: bool, + audio_device_selection_enabled: bool, allow_elevate_sandbox: bool, ) -> bool { builtins_for_input( @@ -64,6 +69,7 @@ pub(crate) fn has_builtin_prefix( connectors_enabled, personality_command_enabled, realtime_conversation_enabled, + audio_device_selection_enabled, allow_elevate_sandbox, ) .into_iter() @@ -77,14 +83,14 @@ mod tests { #[test] fn debug_command_still_resolves_for_dispatch() { - let cmd = find_builtin_command("debug-config", true, true, true, false, false); + let cmd = find_builtin_command("debug-config", true, true, true, false, false, false); assert_eq!(cmd, Some(SlashCommand::DebugConfig)); } #[test] fn clear_command_resolves_for_dispatch() { assert_eq!( - find_builtin_command("clear", true, true, true, false, false), + find_builtin_command("clear", true, true, true, false, false, false), Some(SlashCommand::Clear) ); } @@ -92,7 +98,23 @@ mod tests { #[test] fn realtime_command_is_hidden_when_realtime_is_disabled() { assert_eq!( - find_builtin_command("realtime", true, true, true, false, false), + find_builtin_command("realtime", true, true, true, false, true, false), + None + ); + } + + #[test] + fn settings_command_is_hidden_when_realtime_is_disabled() { + assert_eq!( + 
find_builtin_command("settings", true, true, true, false, false, false), + None + ); + } + + #[test] + fn settings_command_is_hidden_when_audio_device_selection_is_disabled() { + assert_eq!( + find_builtin_command("settings", true, true, true, true, false, false), None ); } diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index a4063ed3d6b..4031cecaf63 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -37,6 +37,9 @@ use std::sync::atomic::Ordering; use std::time::Duration; use std::time::Instant; +use crate::app_event::RealtimeAudioDeviceKind; +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +use crate::audio_device::list_realtime_audio_device_names; use crate::bottom_pane::StatusLineItem; use crate::bottom_pane::StatusLineSetupView; use crate::status::RateLimitWindowDisplay; @@ -855,6 +858,10 @@ impl ChatWidget { && cfg!(not(target_os = "linux")) } + fn realtime_audio_device_selection_enabled(&self) -> bool { + self.realtime_conversation_enabled() && cfg!(feature = "voice-input") + } + /// Synchronize the bottom-pane "task running" indicator with the current lifecycles. 
/// /// The bottom pane only has one running flag, but this module treats it as a derived state of @@ -2882,6 +2889,9 @@ impl ChatWidget { widget .bottom_pane .set_realtime_conversation_enabled(widget.realtime_conversation_enabled()); + widget + .bottom_pane + .set_audio_device_selection_enabled(widget.realtime_audio_device_selection_enabled()); widget .bottom_pane .set_status_line_enabled(!widget.configured_status_line_items().is_empty()); @@ -3056,6 +3066,9 @@ impl ChatWidget { widget .bottom_pane .set_realtime_conversation_enabled(widget.realtime_conversation_enabled()); + widget + .bottom_pane + .set_audio_device_selection_enabled(widget.realtime_audio_device_selection_enabled()); widget .bottom_pane .set_status_line_enabled(!widget.configured_status_line_items().is_empty()); @@ -3219,6 +3232,9 @@ impl ChatWidget { widget .bottom_pane .set_realtime_conversation_enabled(widget.realtime_conversation_enabled()); + widget + .bottom_pane + .set_audio_device_selection_enabled(widget.realtime_audio_device_selection_enabled()); widget .bottom_pane .set_status_line_enabled(!widget.configured_status_line_items().is_empty()); @@ -3531,6 +3547,12 @@ impl ChatWidget { self.start_realtime_conversation(); } } + SlashCommand::Settings => { + if !self.realtime_audio_device_selection_enabled() { + return; + } + self.open_realtime_audio_popup(); + } SlashCommand::Personality => { self.open_personality_popup(); } @@ -4139,7 +4161,7 @@ impl ChatWidget { sandbox_policy: self.config.permissions.sandbox_policy.get().clone(), model: effective_mode.model().to_string(), effort: effective_mode.reasoning_effort(), - summary: self.config.model_reasoning_summary, + summary: None, final_output_json_schema: None, collaboration_mode, personality, @@ -5270,6 +5292,161 @@ impl ChatWidget { }); } + pub(crate) fn open_realtime_audio_popup(&mut self) { + let items = [ + RealtimeAudioDeviceKind::Microphone, + RealtimeAudioDeviceKind::Speaker, + ] + .into_iter() + .map(|kind| { + let description = 
Some(format!( + "Current: {}", + self.current_realtime_audio_selection_label(kind) + )); + let actions: Vec = vec![Box::new(move |tx| { + tx.send(AppEvent::OpenRealtimeAudioDeviceSelection { kind }); + })]; + SelectionItem { + name: kind.title().to_string(), + description, + actions, + dismiss_on_select: true, + ..Default::default() + } + }) + .collect(); + + self.bottom_pane.show_selection_view(SelectionViewParams { + title: Some("Settings".to_string()), + subtitle: Some("Configure settings for Codex.".to_string()), + footer_hint: Some(standard_popup_hint_line()), + items, + ..Default::default() + }); + } + + #[cfg(all(not(target_os = "linux"), feature = "voice-input"))] + pub(crate) fn open_realtime_audio_device_selection(&mut self, kind: RealtimeAudioDeviceKind) { + match list_realtime_audio_device_names(kind) { + Ok(device_names) => { + self.open_realtime_audio_device_selection_with_names(kind, device_names); + } + Err(err) => { + self.add_error_message(format!( + "Failed to load realtime {} devices: {err}", + kind.noun() + )); + } + } + } + + #[cfg(any(target_os = "linux", not(feature = "voice-input")))] + pub(crate) fn open_realtime_audio_device_selection(&mut self, kind: RealtimeAudioDeviceKind) { + let _ = kind; + } + + #[cfg(all(not(target_os = "linux"), feature = "voice-input"))] + fn open_realtime_audio_device_selection_with_names( + &mut self, + kind: RealtimeAudioDeviceKind, + device_names: Vec, + ) { + let current_selection = self.current_realtime_audio_device_name(kind); + let current_available = current_selection + .as_deref() + .is_some_and(|name| device_names.iter().any(|device_name| device_name == name)); + let mut items = vec![SelectionItem { + name: "System default".to_string(), + description: Some("Use your operating system default device.".to_string()), + is_current: current_selection.is_none(), + actions: vec![Box::new(move |tx| { + tx.send(AppEvent::PersistRealtimeAudioDeviceSelection { kind, name: None }); + })], + dismiss_on_select: true, 
+ ..Default::default() + }]; + + if let Some(selection) = current_selection.as_deref() + && !current_available + { + items.push(SelectionItem { + name: format!("Unavailable: {selection}"), + description: Some("Configured device is not currently available.".to_string()), + is_current: true, + is_disabled: true, + disabled_reason: Some("Reconnect the device or choose another one.".to_string()), + ..Default::default() + }); + } + + items.extend(device_names.into_iter().map(|device_name| { + let persisted_name = device_name.clone(); + let actions: Vec = vec![Box::new(move |tx| { + tx.send(AppEvent::PersistRealtimeAudioDeviceSelection { + kind, + name: Some(persisted_name.clone()), + }); + })]; + SelectionItem { + is_current: current_selection.as_deref() == Some(device_name.as_str()), + name: device_name, + actions, + dismiss_on_select: true, + ..Default::default() + } + })); + + let mut header = ColumnRenderable::new(); + header.push(Line::from(format!("Select {}", kind.title()).bold())); + header.push(Line::from( + "Saved devices apply to realtime voice only.".dim(), + )); + + self.bottom_pane.show_selection_view(SelectionViewParams { + header: Box::new(header), + footer_hint: Some(standard_popup_hint_line()), + items, + ..Default::default() + }); + } + + pub(crate) fn open_realtime_audio_restart_prompt(&mut self, kind: RealtimeAudioDeviceKind) { + let restart_actions: Vec = vec![Box::new(move |tx| { + tx.send(AppEvent::RestartRealtimeAudioDevice { kind }); + })]; + let items = vec![ + SelectionItem { + name: "Restart now".to_string(), + description: Some(format!("Restart local {} audio now.", kind.noun())), + actions: restart_actions, + dismiss_on_select: true, + ..Default::default() + }, + SelectionItem { + name: "Apply later".to_string(), + description: Some(format!( + "Keep the current {} until local audio starts again.", + kind.noun() + )), + dismiss_on_select: true, + ..Default::default() + }, + ]; + + let mut header = ColumnRenderable::new(); + 
header.push(Line::from(format!("Restart {} now?", kind.title()).bold())); + header.push(Line::from( + "Configuration is saved. Restart local audio to use it immediately.".dim(), + )); + + self.bottom_pane.show_selection_view(SelectionViewParams { + header: Box::new(header), + footer_hint: Some(standard_popup_hint_line()), + items, + ..Default::default() + }); + } + fn model_menu_header(&self, title: &str, subtitle: &str) -> Box { let title = title.to_string(); let subtitle = subtitle.to_string(); @@ -6523,6 +6700,8 @@ impl ChatWidget { let realtime_conversation_enabled = self.realtime_conversation_enabled(); self.bottom_pane .set_realtime_conversation_enabled(realtime_conversation_enabled); + self.bottom_pane + .set_audio_device_selection_enabled(self.realtime_audio_device_selection_enabled()); if !realtime_conversation_enabled && self.realtime_conversation.is_live() { self.request_realtime_conversation_close(Some( "Realtime voice mode was closed because the feature was disabled.".to_string(), @@ -6612,6 +6791,17 @@ impl ChatWidget { self.config.personality = Some(personality); } + pub(crate) fn set_realtime_audio_device( + &mut self, + kind: RealtimeAudioDeviceKind, + name: Option, + ) { + match kind { + RealtimeAudioDeviceKind::Microphone => self.config.realtime_audio.microphone = name, + RealtimeAudioDeviceKind::Speaker => self.config.realtime_audio.speaker = name, + } + } + /// Set the syntax theme override in the widget's config copy. 
pub(crate) fn set_tui_theme(&mut self, theme: Option) { self.config.tui_theme = theme; @@ -6640,6 +6830,22 @@ impl ChatWidget { .unwrap_or_else(|| self.current_collaboration_mode.model()) } + pub(crate) fn realtime_conversation_is_live(&self) -> bool { + self.realtime_conversation.is_active() + } + + fn current_realtime_audio_device_name(&self, kind: RealtimeAudioDeviceKind) -> Option { + match kind { + RealtimeAudioDeviceKind::Microphone => self.config.realtime_audio.microphone.clone(), + RealtimeAudioDeviceKind::Speaker => self.config.realtime_audio.speaker.clone(), + } + } + + fn current_realtime_audio_selection_label(&self, kind: RealtimeAudioDeviceKind) -> String { + self.current_realtime_audio_device_name(kind) + .unwrap_or_else(|| "System default".to_string()) + } + fn sync_personality_command_enabled(&mut self) { self.bottom_pane .set_personality_command_enabled(self.config.features.enabled(Feature::Personality)); diff --git a/codex-rs/tui/src/chatwidget/agent.rs b/codex-rs/tui/src/chatwidget/agent.rs index 63d519a7a59..e14a9e3628a 100644 --- a/codex-rs/tui/src/chatwidget/agent.rs +++ b/codex-rs/tui/src/chatwidget/agent.rs @@ -13,6 +13,17 @@ use tokio::sync::mpsc::unbounded_channel; use crate::app_event::AppEvent; use crate::app_event_sender::AppEventSender; +const TUI_NOTIFY_CLIENT: &str = "codex-tui"; + +async fn initialize_app_server_client_name(thread: &CodexThread) { + if let Err(err) = thread + .set_app_server_client_name(Some(TUI_NOTIFY_CLIENT.to_string())) + .await + { + tracing::error!("failed to set app server client name: {err}"); + } +} + /// Spawn the agent bootstrapper and op forwarding loop, returning the /// `UnboundedSender` used by the UI to submit operations. pub(crate) fn spawn_agent( @@ -42,6 +53,7 @@ pub(crate) fn spawn_agent( return; } }; + initialize_app_server_client_name(thread.as_ref()).await; // Forward the captured `SessionConfigured` event so it can be rendered in the UI. 
let ev = codex_protocol::protocol::Event { @@ -87,6 +99,8 @@ pub(crate) fn spawn_agent_from_existing( let app_event_tx_clone = app_event_tx; tokio::spawn(async move { + initialize_app_server_client_name(thread.as_ref()).await; + // Forward the captured `SessionConfigured` event so it can be rendered in the UI. let ev = codex_protocol::protocol::Event { id: "".to_string(), @@ -123,6 +137,7 @@ pub(crate) fn spawn_op_forwarder(thread: std::sync::Arc) -> Unbound let (codex_op_tx, mut codex_op_rx) = unbounded_channel::(); tokio::spawn(async move { + initialize_app_server_client_name(thread.as_ref()).await; while let Some(op) = codex_op_rx.recv().await { if let Err(e) = thread.submit(op).await { tracing::error!("failed to submit op: {e}"); diff --git a/codex-rs/tui/src/chatwidget/realtime.rs b/codex-rs/tui/src/chatwidget/realtime.rs index 2cf97188b34..98ac76a34fc 100644 --- a/codex-rs/tui/src/chatwidget/realtime.rs +++ b/codex-rs/tui/src/chatwidget/realtime.rs @@ -41,6 +41,10 @@ impl RealtimeConversationUiState { | RealtimeConversationPhase::Stopping ) } + + pub(super) fn is_active(&self) -> bool { + matches!(self.phase, RealtimeConversationPhase::Active) + } } #[derive(Clone, Debug, PartialEq)] @@ -207,7 +211,7 @@ impl ChatWidget { { if self.realtime_conversation.audio_player.is_none() { self.realtime_conversation.audio_player = - crate::voice::RealtimeAudioPlayer::start().ok(); + crate::voice::RealtimeAudioPlayer::start(&self.config).ok(); } if let Some(player) = &self.realtime_conversation.audio_player && let Err(err) = player.enqueue_frame(frame) @@ -231,7 +235,10 @@ impl ChatWidget { self.realtime_conversation.meter_placeholder_id = Some(placeholder_id.clone()); self.request_redraw(); - let capture = match crate::voice::VoiceCapture::start_realtime(self.app_event_tx.clone()) { + let capture = match crate::voice::VoiceCapture::start_realtime( + &self.config, + self.app_event_tx.clone(), + ) { Ok(capture) => capture, Err(err) => { 
self.remove_transcription_placeholder(&placeholder_id); @@ -250,7 +257,7 @@ impl ChatWidget { self.realtime_conversation.capture = Some(capture); if self.realtime_conversation.audio_player.is_none() { self.realtime_conversation.audio_player = - crate::voice::RealtimeAudioPlayer::start().ok(); + crate::voice::RealtimeAudioPlayer::start(&self.config).ok(); } std::thread::spawn(move || { @@ -275,8 +282,50 @@ impl ChatWidget { #[cfg(target_os = "linux")] fn start_realtime_local_audio(&mut self) {} + #[cfg(all(not(target_os = "linux"), feature = "voice-input"))] + pub(crate) fn restart_realtime_audio_device(&mut self, kind: RealtimeAudioDeviceKind) { + if !self.realtime_conversation.is_active() { + return; + } + + match kind { + RealtimeAudioDeviceKind::Microphone => { + self.stop_realtime_microphone(); + self.start_realtime_local_audio(); + } + RealtimeAudioDeviceKind::Speaker => { + self.stop_realtime_speaker(); + match crate::voice::RealtimeAudioPlayer::start(&self.config) { + Ok(player) => { + self.realtime_conversation.audio_player = Some(player); + } + Err(err) => { + self.add_error_message(format!("Failed to start speaker output: {err}")); + } + } + } + } + self.request_redraw(); + } + + #[cfg(any(target_os = "linux", not(feature = "voice-input")))] + pub(crate) fn restart_realtime_audio_device(&mut self, kind: RealtimeAudioDeviceKind) { + let _ = kind; + } + #[cfg(not(target_os = "linux"))] fn stop_realtime_local_audio(&mut self) { + self.stop_realtime_microphone(); + self.stop_realtime_speaker(); + } + + #[cfg(target_os = "linux")] + fn stop_realtime_local_audio(&mut self) { + self.realtime_conversation.meter_placeholder_id = None; + } + + #[cfg(not(target_os = "linux"))] + fn stop_realtime_microphone(&mut self) { if let Some(flag) = self.realtime_conversation.capture_stop_flag.take() { flag.store(true, Ordering::Relaxed); } @@ -286,13 +335,12 @@ impl ChatWidget { if let Some(id) = self.realtime_conversation.meter_placeholder_id.take() { 
self.remove_transcription_placeholder(&id); } + } + + #[cfg(not(target_os = "linux"))] + fn stop_realtime_speaker(&mut self) { if let Some(player) = self.realtime_conversation.audio_player.take() { player.clear(); } } - - #[cfg(target_os = "linux")] - fn stop_realtime_local_audio(&mut self) { - self.realtime_conversation.meter_placeholder_id = None; - } } diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_audio_selection_popup.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_audio_selection_popup.snap new file mode 100644 index 00000000000..8c60f961f9c --- /dev/null +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_audio_selection_popup.snap @@ -0,0 +1,11 @@ +--- +source: tui/src/chatwidget/tests.rs +expression: popup +--- + Settings + Configure settings for Codex. + +› 1. Microphone Current: System default + 2. Speaker Current: System default + + Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_audio_selection_popup_narrow.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_audio_selection_popup_narrow.snap new file mode 100644 index 00000000000..8c60f961f9c --- /dev/null +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_audio_selection_popup_narrow.snap @@ -0,0 +1,11 @@ +--- +source: tui/src/chatwidget/tests.rs +expression: popup +--- + Settings + Configure settings for Codex. + +› 1. Microphone Current: System default + 2. 
Speaker Current: System default + + Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_microphone_picker_popup.snap b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_microphone_picker_popup.snap new file mode 100644 index 00000000000..3095e6da976 --- /dev/null +++ b/codex-rs/tui/src/chatwidget/snapshots/codex_tui__chatwidget__tests__realtime_microphone_picker_popup.snap @@ -0,0 +1,18 @@ +--- +source: tui/src/chatwidget/tests.rs +expression: popup +--- + Select Microphone + Saved devices apply to realtime voice only. + + 1. System default Use your operating system + default device. +› 2. Unavailable: Studio Mic (current) (disabled) Configured device is not + currently available. + (disabled: Reconnect the + device or choose another + one.) + 3. Built-in Mic + 4. USB Mic + + Press enter to confirm or esc to go back diff --git a/codex-rs/tui/src/chatwidget/tests.rs b/codex-rs/tui/src/chatwidget/tests.rs index 48de3e36813..b0ba5200a2a 100644 --- a/codex-rs/tui/src/chatwidget/tests.rs +++ b/codex-rs/tui/src/chatwidget/tests.rs @@ -7,6 +7,8 @@ use super::*; use crate::app_event::AppEvent; use crate::app_event::ExitMode; +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +use crate::app_event::RealtimeAudioDeviceKind; use crate::app_event_sender::AppEventSender; use crate::bottom_pane::FeedbackAudience; use crate::bottom_pane::LocalImageAttachment; @@ -751,8 +753,8 @@ async fn enter_with_only_remote_images_submits_user_turn() { chat.handle_key_event(KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE)); - let items = match next_submit_op(&mut op_rx) { - Op::UserTurn { items, .. } => items, + let (items, summary) = match next_submit_op(&mut op_rx) { + Op::UserTurn { items, summary, .. 
} => (items, summary), other => panic!("expected Op::UserTurn, got {other:?}"), }; assert_eq!( @@ -761,6 +763,7 @@ async fn enter_with_only_remote_images_submits_user_turn() { image_url: remote_url.clone(), }] ); + assert_eq!(summary, None); assert!(chat.remote_image_urls().is_empty()); let mut user_cell = None; @@ -6000,6 +6003,62 @@ async fn personality_selection_popup_snapshot() { assert_snapshot!("personality_selection_popup", popup); } +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +#[tokio::test] +async fn realtime_audio_selection_popup_snapshot() { + let (mut chat, _rx, _op_rx) = make_chatwidget_manual(Some("gpt-5.2-codex")).await; + chat.open_realtime_audio_popup(); + + let popup = render_bottom_popup(&chat, 80); + assert_snapshot!("realtime_audio_selection_popup", popup); +} + +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +#[tokio::test] +async fn realtime_audio_selection_popup_narrow_snapshot() { + let (mut chat, _rx, _op_rx) = make_chatwidget_manual(Some("gpt-5.2-codex")).await; + chat.open_realtime_audio_popup(); + + let popup = render_bottom_popup(&chat, 56); + assert_snapshot!("realtime_audio_selection_popup_narrow", popup); +} + +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +#[tokio::test] +async fn realtime_microphone_picker_popup_snapshot() { + let (mut chat, _rx, _op_rx) = make_chatwidget_manual(Some("gpt-5.2-codex")).await; + chat.config.realtime_audio.microphone = Some("Studio Mic".to_string()); + chat.open_realtime_audio_device_selection_with_names( + RealtimeAudioDeviceKind::Microphone, + vec!["Built-in Mic".to_string(), "USB Mic".to_string()], + ); + + let popup = render_bottom_popup(&chat, 80); + assert_snapshot!("realtime_microphone_picker_popup", popup); +} + +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +#[tokio::test] +async fn realtime_audio_picker_emits_persist_event() { + let (mut chat, mut rx, _op_rx) = make_chatwidget_manual(Some("gpt-5.2-codex")).await; + 
chat.open_realtime_audio_device_selection_with_names( + RealtimeAudioDeviceKind::Speaker, + vec!["Desk Speakers".to_string(), "Headphones".to_string()], + ); + + chat.handle_key_event(KeyEvent::new(KeyCode::Down, KeyModifiers::NONE)); + chat.handle_key_event(KeyEvent::new(KeyCode::Down, KeyModifiers::NONE)); + chat.handle_key_event(KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE)); + + assert_matches!( + rx.try_recv(), + Ok(AppEvent::PersistRealtimeAudioDeviceSelection { + kind: RealtimeAudioDeviceKind::Speaker, + name: Some(name), + }) if name == "Headphones" + ); +} + #[tokio::test] async fn model_picker_hides_show_in_picker_false_models_from_cache() { let (mut chat, _rx, _op_rx) = make_chatwidget_manual(Some("test-visible-model")).await; @@ -6018,6 +6077,7 @@ async fn model_picker_hides_show_in_picker_false_models_from_cache() { is_default: false, upgrade: None, show_in_picker, + availability_nux: None, supported_in_api: true, input_modalities: default_input_modalities(), }; @@ -6286,6 +6346,7 @@ async fn single_reasoning_option_skips_selection() { is_default: false, upgrade: None, show_in_picker: true, + availability_nux: None, supported_in_api: true, input_modalities: default_input_modalities(), }; diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index affbb7297d5..624d47d9154 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -2633,6 +2633,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }; let mut servers = config.mcp_servers.get().clone(); servers.insert("docs".to_string(), stdio_config); @@ -2656,6 +2657,7 @@ mod tests { enabled_tools: None, disabled_tools: None, scopes: None, + oauth_resource: None, }; servers.insert("http".to_string(), http_config); config diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index d51f42b2fff..8dfdbcd8717 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -31,8 +31,10 @@ use 
codex_core::find_thread_path_by_name_str; use codex_core::format_exec_policy_error_with_source; use codex_core::path_utils; use codex_core::read_session_meta_line; +use codex_core::state_db::get_state_db; use codex_core::terminal::Multiplexer; use codex_core::windows_sandbox::WindowsSandboxLevelExt; +use codex_protocol::ThreadId; use codex_protocol::config_types::AltScreenMode; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::WindowsSandboxLevel; @@ -61,6 +63,8 @@ mod app_backtrack; mod app_event; mod app_event_sender; mod ascii_animation; +#[cfg(all(not(target_os = "linux"), feature = "voice-input"))] +mod audio_device; mod bottom_pane; mod chatwidget; mod cli; @@ -121,6 +125,7 @@ mod voice; mod voice { use crate::app_event::AppEvent; use crate::app_event_sender::AppEventSender; + use codex_core::config::Config; use codex_protocol::protocol::RealtimeAudioFrame; use std::sync::Arc; use std::sync::Mutex; @@ -144,7 +149,7 @@ mod voice { Err("voice input is unavailable in this build".to_string()) } - pub fn start_realtime(_tx: AppEventSender) -> Result { + pub fn start_realtime(_config: &Config, _tx: AppEventSender) -> Result { Err("voice input is unavailable in this build".to_string()) } @@ -184,7 +189,7 @@ mod voice { } impl RealtimeAudioPlayer { - pub(crate) fn start() -> Result { + pub(crate) fn start(_config: &Config) -> Result { Err("voice output is unavailable in this build".to_string()) } @@ -665,7 +670,19 @@ async fn run_ratatui_app( find_thread_path_by_name_str(&config.codex_home, id_str).await? 
}; match path { - Some(path) => resume_picker::SessionSelection::Fork(path), + Some(path) => { + let thread_id = + match resolve_session_thread_id(path.as_path(), is_uuid.then_some(id_str)) + .await + { + Some(thread_id) => thread_id, + None => return missing_session_exit(id_str, "fork"), + }; + resume_picker::SessionSelection::Fork(resume_picker::SessionTarget { + path, + thread_id, + }) + } None => return missing_session_exit(id_str, "fork"), } } else if cli.fork_last { @@ -682,11 +699,37 @@ async fn run_ratatui_app( ) .await { - Ok(page) => page - .items - .first() - .map(|it| resume_picker::SessionSelection::Fork(it.path.clone())) - .unwrap_or(resume_picker::SessionSelection::StartFresh), + Ok(page) => match page.items.first() { + Some(item) => { + match resolve_session_thread_id(item.path.as_path(), None).await { + Some(thread_id) => resume_picker::SessionSelection::Fork( + resume_picker::SessionTarget { + path: item.path.clone(), + thread_id, + }, + ), + None => { + let rollout_path = item.path.display(); + error!( + "Error reading session metadata from latest rollout: {rollout_path}" + ); + restore(); + session_log::log_session_end(); + let _ = tui.terminal.clear(); + return Ok(AppExitInfo { + token_usage: codex_protocol::protocol::TokenUsage::default(), + thread_id: None, + thread_name: None, + update_action: None, + exit_reason: ExitReason::Fatal(format!( + "Found latest saved session at {rollout_path}, but failed to read its metadata. Run `codex fork` to choose from existing sessions." + )), + }); + } + } + } + None => resume_picker::SessionSelection::StartFresh, + }, Err(_) => resume_picker::SessionSelection::StartFresh, } } else if cli.fork_picker { @@ -715,7 +758,21 @@ async fn run_ratatui_app( find_thread_path_by_name_str(&config.codex_home, id_str).await? 
}; match path { - Some(path) => resume_picker::SessionSelection::Resume(path), + Some(path) => { + let thread_id = match resolve_session_thread_id( + path.as_path(), + is_uuid.then_some(id_str), + ) + .await + { + Some(thread_id) => thread_id, + None => return missing_session_exit(id_str, "resume"), + }; + resume_picker::SessionSelection::Resume(resume_picker::SessionTarget { + path, + thread_id, + }) + } None => return missing_session_exit(id_str, "resume"), } } else if cli.resume_last { @@ -737,7 +794,30 @@ async fn run_ratatui_app( ) .await { - Ok(Some(path)) => resume_picker::SessionSelection::Resume(path), + Ok(Some(path)) => match resolve_session_thread_id(path.as_path(), None).await { + Some(thread_id) => { + resume_picker::SessionSelection::Resume(resume_picker::SessionTarget { + path, + thread_id, + }) + } + None => { + let rollout_path = path.display(); + error!("Error reading session metadata from latest rollout: {rollout_path}"); + restore(); + session_log::log_session_end(); + let _ = tui.terminal.clear(); + return Ok(AppExitInfo { + token_usage: codex_protocol::protocol::TokenUsage::default(), + thread_id: None, + thread_name: None, + update_action: None, + exit_reason: ExitReason::Fatal(format!( + "Found latest saved session at {rollout_path}, but failed to read its metadata. Run `codex resume` to choose from existing sessions." 
+ )), + }); + } + }, _ => resume_picker::SessionSelection::StartFresh, } } else if cli.resume_picker { @@ -761,15 +841,27 @@ async fn run_ratatui_app( let current_cwd = config.cwd.clone(); let allow_prompt = cli.cwd.is_none(); - let action_and_path_if_resume_or_fork = match &session_selection { - resume_picker::SessionSelection::Resume(path) => Some((CwdPromptAction::Resume, path)), - resume_picker::SessionSelection::Fork(path) => Some((CwdPromptAction::Fork, path)), + let action_and_target_session_if_resume_or_fork = match &session_selection { + resume_picker::SessionSelection::Resume(target_session) => { + Some((CwdPromptAction::Resume, target_session)) + } + resume_picker::SessionSelection::Fork(target_session) => { + Some((CwdPromptAction::Fork, target_session)) + } _ => None, }; - let fallback_cwd = match action_and_path_if_resume_or_fork { - Some((action, path)) => { - match resolve_cwd_for_resume_or_fork(&mut tui, ¤t_cwd, path, action, allow_prompt) - .await? + let fallback_cwd = match action_and_target_session_if_resume_or_fork { + Some((action, target_session)) => { + match resolve_cwd_for_resume_or_fork( + &mut tui, + &config, + ¤t_cwd, + target_session.thread_id, + &target_session.path, + action, + allow_prompt, + ) + .await? 
{ ResolveCwdOutcome::Continue(cwd) => cwd, ResolveCwdOutcome::Exit => { @@ -851,12 +943,35 @@ async fn run_ratatui_app( app_result } -pub(crate) async fn read_session_cwd(path: &Path) -> Option { +pub(crate) async fn resolve_session_thread_id( + path: &Path, + id_str_if_uuid: Option<&str>, +) -> Option { + match id_str_if_uuid { + Some(id_str) => ThreadId::from_string(id_str).ok(), + None => read_session_meta_line(path) + .await + .ok() + .map(|meta_line| meta_line.meta.id), + } +} + +pub(crate) async fn read_session_cwd( + config: &Config, + thread_id: ThreadId, + path: &Path, +) -> Option { + if let Some(state_db_ctx) = get_state_db(config, None).await + && let Ok(Some(metadata)) = state_db_ctx.get_thread(thread_id).await + { + return Some(metadata.cwd); + } + // Prefer the latest TurnContext cwd so resume/fork reflects the most recent - // session directory (for the changed-cwd prompt). The alternative would be - // mutating the SessionMeta line when the session cwd changes, but the rollout - // is an append-only JSONL log and rewriting the head would be error-prone. - // When rollouts move to SQLite, we can drop this scan. + // session directory (for the changed-cwd prompt) when DB data is unavailable. + // The alternative would be mutating the SessionMeta line when the session cwd + // changes, but the rollout is an append-only JSONL log and rewriting the head + // would be error-prone. 
if let Some(cwd) = parse_latest_turn_context_cwd(path).await { return Some(cwd); } @@ -908,12 +1023,14 @@ pub(crate) enum ResolveCwdOutcome { pub(crate) async fn resolve_cwd_for_resume_or_fork( tui: &mut Tui, + config: &Config, current_cwd: &Path, + thread_id: ThreadId, path: &Path, action: CwdPromptAction, allow_prompt: bool, ) -> color_eyre::Result { - let Some(history_cwd) = read_session_cwd(path).await else { + let Some(history_cwd) = read_session_cwd(config, thread_id, path).await else { return Ok(ResolveCwdOutcome::Continue(None)); }; if allow_prompt && cwds_differ(current_cwd, &history_cwd) { @@ -1071,11 +1188,14 @@ mod tests { use codex_core::config::ConfigBuilder; use codex_core::config::ConfigOverrides; use codex_core::config::ProjectConfig; + use codex_core::features::Feature; + use codex_protocol::ThreadId; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::RolloutLine; use codex_protocol::protocol::SessionMeta; use codex_protocol::protocol::SessionMetaLine; + use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::TurnContextItem; use serial_test::serial; use tempfile::TempDir; @@ -1152,6 +1272,8 @@ mod tests { TurnContextItem { turn_id: None, cwd, + current_date: None, + timezone: None, approval_policy: config.permissions.approval_policy.value(), sandbox_policy: config.permissions.sandbox_policy.get().clone(), network: None, @@ -1159,7 +1281,9 @@ mod tests { personality: None, collaboration_mode: None, effort: config.model_reasoning_effort, - summary: config.model_reasoning_summary, + summary: config + .model_reasoning_summary + .unwrap_or(codex_protocol::config_types::ReasoningSummary::Auto), user_instructions: None, developer_instructions: None, final_output_json_schema: None, @@ -1194,7 +1318,9 @@ mod tests { } std::fs::write(&rollout_path, text)?; - let cwd = read_session_cwd(&rollout_path).await.expect("expected cwd"); + let cwd = read_session_cwd(&config, 
ThreadId::new(), &rollout_path) + .await + .expect("expected cwd"); assert_eq!(cwd, second); Ok(()) } @@ -1234,7 +1360,9 @@ mod tests { } std::fs::write(&rollout_path, text)?; - let session_cwd = read_session_cwd(&rollout_path).await.expect("expected cwd"); + let session_cwd = read_session_cwd(&config, ThreadId::new(), &rollout_path) + .await + .expect("expected cwd"); assert_eq!(session_cwd, latest); assert!(cwds_differ(¤t, &session_cwd)); Ok(()) @@ -1339,7 +1467,7 @@ trust_level = "untrusted" #[tokio::test] async fn read_session_cwd_falls_back_to_session_meta() -> std::io::Result<()> { let temp_dir = TempDir::new()?; - let _config = build_config(&temp_dir).await?; + let config = build_config(&temp_dir).await?; let session_cwd = temp_dir.path().join("session"); std::fs::create_dir_all(&session_cwd)?; @@ -1361,8 +1489,67 @@ trust_level = "untrusted" ); std::fs::write(&rollout_path, text)?; - let cwd = read_session_cwd(&rollout_path).await.expect("expected cwd"); + let cwd = read_session_cwd(&config, ThreadId::new(), &rollout_path) + .await + .expect("expected cwd"); assert_eq!(cwd, session_cwd); Ok(()) } + + #[tokio::test] + async fn read_session_cwd_prefers_sqlite_when_thread_id_present() -> std::io::Result<()> { + let temp_dir = TempDir::new()?; + let mut config = build_config(&temp_dir).await?; + config.features.enable(Feature::Sqlite); + + let thread_id = ThreadId::new(); + let rollout_cwd = temp_dir.path().join("rollout-cwd"); + let sqlite_cwd = temp_dir.path().join("sqlite-cwd"); + std::fs::create_dir_all(&rollout_cwd)?; + std::fs::create_dir_all(&sqlite_cwd)?; + + let rollout_path = temp_dir.path().join("rollout.jsonl"); + let rollout_line = RolloutLine { + timestamp: "t0".to_string(), + item: RolloutItem::TurnContext(build_turn_context(&config, rollout_cwd)), + }; + std::fs::write( + &rollout_path, + format!( + "{}\n", + serde_json::to_string(&rollout_line).expect("serialize rollout") + ), + )?; + + let runtime = codex_state::StateRuntime::init( + 
config.codex_home.clone(), + config.model_provider_id.clone(), + None, + ) + .await + .map_err(std::io::Error::other)?; + runtime + .mark_backfill_complete(None) + .await + .map_err(std::io::Error::other)?; + + let mut builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + rollout_path.clone(), + chrono::Utc::now(), + SessionSource::Cli, + ); + builder.cwd = sqlite_cwd.clone(); + let metadata = builder.build(config.model_provider_id.as_str()); + runtime + .upsert_thread(&metadata) + .await + .map_err(std::io::Error::other)?; + + let cwd = read_session_cwd(&config, thread_id, &rollout_path) + .await + .expect("expected cwd"); + assert_eq!(cwd, sqlite_cwd); + Ok(()) + } } diff --git a/codex-rs/tui/src/resume_picker.rs b/codex-rs/tui/src/resume_picker.rs index a827ccab50b..ad815628487 100644 --- a/codex-rs/tui/src/resume_picker.rs +++ b/codex-rs/tui/src/resume_picker.rs @@ -39,11 +39,18 @@ use unicode_width::UnicodeWidthStr; const PAGE_SIZE: usize = 25; const LOAD_NEAR_THRESHOLD: usize = 5; + +#[derive(Debug, Clone)] +pub struct SessionTarget { + pub path: PathBuf, + pub thread_id: ThreadId, +} + #[derive(Debug, Clone)] pub enum SessionSelection { StartFresh, - Resume(PathBuf), - Fork(PathBuf), + Resume(SessionTarget), + Fork(SessionTarget), Exit, } @@ -68,10 +75,11 @@ impl SessionPickerAction { } } - fn selection(self, path: PathBuf) -> SessionSelection { + fn selection(self, path: PathBuf, thread_id: ThreadId) -> SessionSelection { + let target_session = SessionTarget { path, thread_id }; match self { - SessionPickerAction::Resume => SessionSelection::Resume(path), - SessionPickerAction::Fork => SessionSelection::Fork(path), + SessionPickerAction::Resume => SessionSelection::Resume(target_session), + SessionPickerAction::Fork => SessionSelection::Fork(target_session), } } } @@ -266,6 +274,7 @@ struct PickerState { action: SessionPickerAction, sort_key: ThreadSortKey, thread_name_cache: HashMap>, + inline_error: Option, } struct PaginationState { @@ -383,6 
+392,7 @@ impl PickerState { action, sort_key: ThreadSortKey::CreatedAt, thread_name_cache: HashMap::new(), + inline_error: None, } } @@ -391,6 +401,7 @@ impl PickerState { } async fn handle_key(&mut self, key: KeyEvent) -> Result> { + self.inline_error = None; match key.code { KeyCode::Esc => return Ok(Some(SessionSelection::StartFresh)), KeyCode::Char('c') @@ -402,7 +413,19 @@ impl PickerState { } KeyCode::Enter => { if let Some(row) = self.filtered_rows.get(self.selected) { - return Ok(Some(self.action.selection(row.path.clone()))); + let path = row.path.clone(); + let thread_id = match row.thread_id { + Some(thread_id) => Some(thread_id), + None => crate::resolve_session_thread_id(path.as_path(), None).await, + }; + if let Some(thread_id) = thread_id { + return Ok(Some(self.action.selection(path, thread_id))); + } + self.inline_error = Some(format!( + "Failed to read session metadata from {}", + path.display() + )); + self.request_frame(); } } KeyCode::Up => { @@ -866,12 +889,7 @@ fn draw_picker(tui: &mut Tui, state: &PickerState) -> std::io::Result<()> { frame.render_widget_ref(header_line, header); // Search line - let q = if state.query.is_empty() { - "Type to search".dim().to_string() - } else { - format!("Search: {}", state.query) - }; - frame.render_widget_ref(Line::from(q), search); + frame.render_widget_ref(search_line(state), search); let metrics = calculate_column_metrics(&state.filtered_rows, state.show_all); @@ -904,6 +922,16 @@ fn draw_picker(tui: &mut Tui, state: &PickerState) -> std::io::Result<()> { }) } +fn search_line(state: &PickerState) -> Line<'_> { + if let Some(error) = state.inline_error.as_deref() { + return Line::from(error.red()); + } + if state.query.is_empty() { + return Line::from("Type to search".dim()); + } + Line::from(format!("Search: {}", state.query)) +} + fn render_list( frame: &mut crate::custom_terminal::Frame, area: Rect, @@ -1607,6 +1635,42 @@ mod tests { assert_snapshot!("resume_picker_table", snapshot); } + #[test] + 
fn resume_search_error_snapshot() { + use crate::custom_terminal::Terminal; + use crate::test_backend::VT100Backend; + + let loader: PageLoader = Arc::new(|_| {}); + let mut state = PickerState::new( + PathBuf::from("/tmp"), + FrameRequester::test_dummy(), + loader, + String::from("openai"), + true, + None, + SessionPickerAction::Resume, + ); + state.inline_error = Some(String::from( + "Failed to read session metadata from /tmp/missing.jsonl", + )); + + let width: u16 = 80; + let height: u16 = 1; + let backend = VT100Backend::new(width, height); + let mut terminal = Terminal::with_options(backend).expect("terminal"); + terminal.set_viewport_area(Rect::new(0, 0, width, height)); + + { + let mut frame = terminal.get_frame(); + let line = search_line(&state); + frame.render_widget_ref(line, frame.area()); + } + terminal.flush().expect("flush"); + + let snapshot = terminal.backend().to_string(); + assert_snapshot!("resume_picker_search_error", snapshot); + } + // TODO(jif) fix // #[tokio::test] // async fn resume_picker_screen_snapshot() { @@ -2102,6 +2166,46 @@ mod tests { assert_eq!(state.selected, 5); } + #[tokio::test] + async fn enter_on_row_without_resolvable_thread_id_shows_inline_error() { + let loader: PageLoader = Arc::new(|_| {}); + let mut state = PickerState::new( + PathBuf::from("/tmp"), + FrameRequester::test_dummy(), + loader, + String::from("openai"), + true, + None, + SessionPickerAction::Resume, + ); + + let row = Row { + path: PathBuf::from("/tmp/missing.jsonl"), + preview: String::from("missing metadata"), + thread_id: None, + thread_name: None, + created_at: None, + updated_at: None, + cwd: None, + git_branch: None, + }; + state.all_rows = vec![row.clone()]; + state.filtered_rows = vec![row]; + + let selection = state + .handle_key(KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE)) + .await + .expect("enter should not abort the picker"); + + assert!(selection.is_none()); + assert_eq!( + state.inline_error, + Some(String::from( + "Failed to read 
session metadata from /tmp/missing.jsonl" + )) + ); + } + #[tokio::test] async fn up_at_bottom_does_not_scroll_when_visible() { let loader: PageLoader = Arc::new(|_| {}); diff --git a/codex-rs/tui/src/slash_command.rs b/codex-rs/tui/src/slash_command.rs index 2799c80b247..bbd0307da3f 100644 --- a/codex-rs/tui/src/slash_command.rs +++ b/codex-rs/tui/src/slash_command.rs @@ -51,6 +51,7 @@ pub enum SlashCommand { Clear, Personality, Realtime, + Settings, TestApproval, // Debugging commands. #[strum(serialize = "debug-m-drop")] @@ -89,6 +90,7 @@ impl SlashCommand { SlashCommand::Model => "choose what model and reasoning effort to use", SlashCommand::Personality => "choose a communication style for Codex", SlashCommand::Realtime => "toggle realtime voice mode (experimental)", + SlashCommand::Settings => "configure realtime microphone/speaker", SlashCommand::Plan => "switch to Plan mode", SlashCommand::Collab => "change collaboration mode (experimental)", SlashCommand::Agent => "switch the active agent thread", @@ -163,6 +165,7 @@ impl SlashCommand { SlashCommand::Rollout => true, SlashCommand::TestApproval => true, SlashCommand::Realtime => true, + SlashCommand::Settings => true, SlashCommand::Collab => true, SlashCommand::Agent => true, SlashCommand::Statusline => false, diff --git a/codex-rs/tui/src/snapshots/codex_tui__resume_picker__tests__resume_picker_search_error.snap b/codex-rs/tui/src/snapshots/codex_tui__resume_picker__tests__resume_picker_search_error.snap new file mode 100644 index 00000000000..4f48ef96af6 --- /dev/null +++ b/codex-rs/tui/src/snapshots/codex_tui__resume_picker__tests__resume_picker_search_error.snap @@ -0,0 +1,5 @@ +--- +source: tui/src/resume_picker.rs +expression: snapshot +--- +Failed to read session metadata from /tmp/missing.jsonl diff --git a/codex-rs/tui/src/status/card.rs b/codex-rs/tui/src/status/card.rs index acfa604e97c..a7546230bf5 100644 --- a/codex-rs/tui/src/status/card.rs +++ b/codex-rs/tui/src/status/card.rs @@ -185,7 
+185,10 @@ impl StatusHistoryCell { config_entries.push(("reasoning effort", effort_value)); config_entries.push(( "reasoning summaries", - config.model_reasoning_summary.to_string(), + config + .model_reasoning_summary + .map(|summary| summary.to_string()) + .unwrap_or_else(|| "auto".to_string()), )); } let (model_name, model_details) = compose_model_display(model_name, &config_entries); diff --git a/codex-rs/tui/src/status/tests.rs b/codex-rs/tui/src/status/tests.rs index fb1aad72753..16ab60bd72e 100644 --- a/codex-rs/tui/src/status/tests.rs +++ b/codex-rs/tui/src/status/tests.rs @@ -96,7 +96,7 @@ async fn status_snapshot_includes_reasoning_details() { let mut config = test_config(&temp_home).await; config.model = Some("gpt-5.1-codex-max".to_string()); config.model_provider_id = "openai".to_string(); - config.model_reasoning_summary = ReasoningSummary::Detailed; + config.model_reasoning_summary = Some(ReasoningSummary::Detailed); config .permissions .sandbox_policy @@ -596,7 +596,7 @@ async fn status_snapshot_truncates_in_narrow_terminal() { let mut config = test_config(&temp_home).await; config.model = Some("gpt-5.1-codex-max".to_string()); config.model_provider_id = "openai".to_string(); - config.model_reasoning_summary = ReasoningSummary::Detailed; + config.model_reasoning_summary = Some(ReasoningSummary::Detailed); config.cwd = PathBuf::from("/workspace/tests"); let auth_manager = test_auth_manager(&config); diff --git a/codex-rs/tui/src/streaming/commit_tick.rs b/codex-rs/tui/src/streaming/commit_tick.rs index bda63bccf63..6c287ff9c37 100644 --- a/codex-rs/tui/src/streaming/commit_tick.rs +++ b/codex-rs/tui/src/streaming/commit_tick.rs @@ -83,21 +83,11 @@ pub(crate) fn run_commit_tick( return CommitTickOutput::default(); } - let output = apply_commit_tick_plan( + apply_commit_tick_plan( decision.drain_plan, stream_controller, plan_stream_controller, - ); - tracing::trace!( - mode = ?decision.mode, - queued_lines = snapshot.queued_lines, - 
oldest_queued_age_ms = snapshot.oldest_age.map(|age| age.as_millis() as u64), - drain_plan = ?decision.drain_plan, - has_controller = output.has_controller, - all_idle = output.all_idle, - "stream chunking commit tick" - ); - output + ) } /// Builds the combined queue-pressure snapshot consumed by chunking policy. diff --git a/codex-rs/tui/src/voice.rs b/codex-rs/tui/src/voice.rs index 443ccc88a0d..6c4236ac810 100644 --- a/codex-rs/tui/src/voice.rs +++ b/codex-rs/tui/src/voice.rs @@ -51,7 +51,7 @@ pub struct VoiceCapture { impl VoiceCapture { pub fn start() -> Result { - let (device, config) = select_input_device_and_config()?; + let (device, config) = select_default_input_device_and_config()?; let sample_rate = config.sample_rate().0; let channels = config.channels(); @@ -74,8 +74,8 @@ impl VoiceCapture { }) } - pub fn start_realtime(tx: AppEventSender) -> Result { - let (device, config) = select_input_device_and_config()?; + pub fn start_realtime(config: &Config, tx: AppEventSender) -> Result { + let (device, config) = select_realtime_input_device_and_config(config)?; let sample_rate = config.sample_rate().0; let channels = config.channels(); @@ -262,7 +262,8 @@ pub fn transcribe_async( // Voice input helpers // ------------------------- -fn select_input_device_and_config() -> Result<(cpal::Device, cpal::SupportedStreamConfig), String> { +fn select_default_input_device_and_config() +-> Result<(cpal::Device, cpal::SupportedStreamConfig), String> { let host = cpal::default_host(); let device = host .default_input_device() @@ -273,6 +274,12 @@ fn select_input_device_and_config() -> Result<(cpal::Device, cpal::SupportedStre Ok((device, config)) } +fn select_realtime_input_device_and_config( + config: &Config, +) -> Result<(cpal::Device, cpal::SupportedStreamConfig), String> { + crate::audio_device::select_configured_input_device_and_config(config) +} + fn build_input_stream( device: &cpal::Device, config: &cpal::SupportedStreamConfig, @@ -466,14 +473,9 @@ pub(crate) 
struct RealtimeAudioPlayer { } impl RealtimeAudioPlayer { - pub(crate) fn start() -> Result { - let host = cpal::default_host(); - let device = host - .default_output_device() - .ok_or_else(|| "no output audio device available".to_string())?; - let config = device - .default_output_config() - .map_err(|e| format!("failed to get default output config: {e}"))?; + pub(crate) fn start(config: &Config) -> Result { + let (device, config) = + crate::audio_device::select_configured_output_device_and_config(config)?; let output_sample_rate = config.sample_rate().0; let output_channels = config.channels(); let queue = Arc::new(Mutex::new(VecDeque::new())); diff --git a/codex-rs/utils/sandbox-summary/src/config_summary.rs b/codex-rs/utils/sandbox-summary/src/config_summary.rs index a83a65723bd..0f6b1929790 100644 --- a/codex-rs/utils/sandbox-summary/src/config_summary.rs +++ b/codex-rs/utils/sandbox-summary/src/config_summary.rs @@ -28,7 +28,10 @@ pub fn create_config_summary_entries(config: &Config, model: &str) -> Vec<(&'sta )); entries.push(( "reasoning summaries", - config.model_reasoning_summary.to_string(), + config + .model_reasoning_summary + .map(|summary| summary.to_string()) + .unwrap_or_else(|| "none".to_string()), )); } diff --git a/codex-rs/windows-sandbox-rs/src/setup_orchestrator.rs b/codex-rs/windows-sandbox-rs/src/setup_orchestrator.rs index e7f9bee69f2..d612ea6583a 100644 --- a/codex-rs/windows-sandbox-rs/src/setup_orchestrator.rs +++ b/codex-rs/windows-sandbox-rs/src/setup_orchestrator.rs @@ -38,6 +38,18 @@ pub const ONLINE_USERNAME: &str = "CodexSandboxOnline"; const ERROR_CANCELLED: u32 = 1223; const SECURITY_BUILTIN_DOMAIN_RID: u32 = 0x0000_0020; const DOMAIN_ALIAS_RID_ADMINS: u32 = 0x0000_0220; +const USERPROFILE_READ_ROOT_EXCLUSIONS: &[&str] = &[ + ".ssh", + ".gnupg", + ".aws", + ".azure", + ".kube", + ".docker", + ".config", + ".npm", + ".pki", + ".terraform.d", +]; pub fn sandbox_dir(codex_home: &Path) -> PathBuf { codex_home.join(".sandbox") @@ 
-245,6 +257,25 @@ fn canonical_existing(paths: &[PathBuf]) -> Vec { .collect() } +fn profile_read_roots(user_profile: &Path) -> Vec { + let entries = match std::fs::read_dir(user_profile) { + Ok(entries) => entries, + Err(_) => return vec![user_profile.to_path_buf()], + }; + + entries + .filter_map(Result::ok) + .map(|entry| (entry.file_name(), entry.path())) + .filter(|(name, _)| { + let name = name.to_string_lossy(); + !USERPROFILE_READ_ROOT_EXCLUSIONS + .iter() + .any(|excluded| name.eq_ignore_ascii_case(excluded)) + }) + .map(|(_, path)| path) + .collect() +} + pub(crate) fn gather_read_roots(command_cwd: &Path, policy: &SandboxPolicy) -> Vec { let mut roots: Vec = Vec::new(); if let Ok(exe) = std::env::current_exe() { @@ -261,7 +292,7 @@ pub(crate) fn gather_read_roots(command_cwd: &Path, policy: &SandboxPolicy) -> V roots.push(p); } if let Ok(up) = std::env::var("USERPROFILE") { - roots.push(PathBuf::from(up)); + roots.extend(profile_read_roots(Path::new(&up))); } roots.push(command_cwd.to_path_buf()); if let SandboxPolicy::WorkspaceWrite { writable_roots, .. 
} = policy { @@ -578,3 +609,44 @@ fn filter_sensitive_write_roots(mut roots: Vec, codex_home: &Path) -> V }); roots } + +#[cfg(test)] +mod tests { + use super::profile_read_roots; + use pretty_assertions::assert_eq; + use std::collections::HashSet; + use std::fs; + use std::path::PathBuf; + use tempfile::TempDir; + + #[test] + fn profile_read_roots_excludes_configured_top_level_entries() { + let tmp = TempDir::new().expect("tempdir"); + let user_profile = tmp.path(); + let allowed_dir = user_profile.join("Documents"); + let allowed_file = user_profile.join(".gitconfig"); + let excluded_dir = user_profile.join(".ssh"); + let excluded_case_variant = user_profile.join(".AWS"); + + fs::create_dir_all(&allowed_dir).expect("create allowed dir"); + fs::write(&allowed_file, "safe").expect("create allowed file"); + fs::create_dir_all(&excluded_dir).expect("create excluded dir"); + fs::create_dir_all(&excluded_case_variant).expect("create excluded case variant"); + + let roots = profile_read_roots(user_profile); + let actual: HashSet = roots.into_iter().collect(); + let expected: HashSet = [allowed_dir, allowed_file].into_iter().collect(); + + assert_eq!(expected, actual); + } + + #[test] + fn profile_read_roots_falls_back_to_profile_root_when_enumeration_fails() { + let tmp = TempDir::new().expect("tempdir"); + let missing_profile = tmp.path().join("missing-user-profile"); + + let roots = profile_read_roots(&missing_profile); + + assert_eq!(vec![missing_profile], roots); + } +} diff --git a/docs/config.md b/docs/config.md index 30665bb11ba..fc9d62b8e80 100644 --- a/docs/config.md +++ b/docs/config.md @@ -24,6 +24,8 @@ Codex can run a notification hook when the agent finishes a turn. See the config - https://developers.openai.com/codex/config-reference +When Codex knows which client started the turn, the legacy notify JSON payload also includes a top-level `client` field. The TUI reports `codex-tui`, and the app server reports the `clientInfo.name` value from `initialize`. 
+ ## JSON Schema The generated JSON Schema for `config.toml` lives at `codex-rs/core/config.schema.json`. diff --git a/docs/js_repl.md b/docs/js_repl.md index eb0ea84ffd9..b36da5dc082 100644 --- a/docs/js_repl.md +++ b/docs/js_repl.md @@ -65,10 +65,38 @@ For `CODEX_JS_REPL_NODE_MODULE_DIRS` and `js_repl_node_module_dirs`, module reso - `codex.tmpDir`: per-session scratch directory path. - `codex.tool(name, args?)`: executes a normal Codex tool call from inside `js_repl` (including shell tools like `shell` / `shell_command` when available). +- Each `codex.tool(...)` call emits a bounded summary at `info` level from the `codex_core::tools::js_repl` logger. At `trace` level, the same path also logs the exact raw response object or error string seen by JavaScript. - To share generated images with the model, write a file under `codex.tmpDir`, call `await codex.tool("view_image", { path: "/absolute/path" })`, then delete the file. Avoid writing directly to `process.stdout` / `process.stderr` / `process.stdin`; the kernel uses a JSON-line transport over stdio. +## Debug logging + +Nested `codex.tool(...)` diagnostics are emitted through normal `tracing` output instead of rollout history. + +- `info` level logs a bounded summary. +- `trace` level also logs the exact serialized response object or error string seen by JavaScript. + +For `codex app-server`, these logs are written to the server process `stderr`. + +Examples: + +```sh +RUST_LOG=codex_core::tools::js_repl=info \ +LOG_FORMAT=json \ +codex app-server \ +2> /tmp/codex-app-server.log +``` + +```sh +RUST_LOG=codex_core::tools::js_repl=trace \ +LOG_FORMAT=json \ +codex app-server \ +2> /tmp/codex-app-server.log +``` + +In both cases, inspect `/tmp/codex-app-server.log` or whatever sink captures the process `stderr`. + ## Vendored parser asset (`meriyah.umd.min.js`) The kernel embeds a vendored Meriyah bundle at: