From b2263a7db685bd6eb73b9803f62e5b4d6ae5e67f Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 5 Mar 2026 22:30:52 -0500 Subject: [PATCH 1/7] Add support for full web search config --- .../codex_app_server_protocol.schemas.json | 123 ++++++++++++++++++ .../codex_app_server_protocol.v2.schemas.json | 123 ++++++++++++++++++ .../schema/json/v2/ConfigReadResponse.json | 123 ++++++++++++++++++ .../schema/typescript/WebSearchConfig.ts | 8 ++ .../schema/typescript/WebSearchContextSize.ts | 5 + .../schema/typescript/WebSearchFilters.ts | 5 + .../typescript/WebSearchUserLocation.ts | 6 + .../typescript/WebSearchUserLocationType.ts | 5 + .../schema/typescript/index.ts | 5 + .../schema/typescript/v2/Config.ts | 3 +- .../schema/typescript/v2/ProfileV2.ts | 3 +- .../app-server-protocol/src/protocol/v2.rs | 3 + .../app-server/tests/suite/v2/config_rpc.rs | 62 +++++++++ codex-rs/core/config.schema.json | 78 +++++++++++ codex-rs/core/src/client_common.rs | 77 +++++++++++ codex-rs/core/src/codex.rs | 3 + codex-rs/core/src/config/mod.rs | 69 ++++++++++ codex-rs/core/src/config/profile.rs | 2 + codex-rs/core/src/tools/spec.rs | 74 +++++++++++ codex-rs/protocol/src/config_types.rs | 44 +++++++ 20 files changed, 819 insertions(+), 2 deletions(-) create mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts create mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchContextSize.ts create mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts create mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts create mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json index 0bebb007cb8..debeacd2de1 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json @@ -9595,6 +9595,16 @@ "type": "null" } ] + }, + "web_search_config": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchConfig" + }, + { + "type": "null" + } + ] } }, "type": "object" @@ -11998,6 +12008,16 @@ "type": "null" } ] + }, + "web_search_config": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchConfig" + }, + { + "type": "null" + } + ] } }, "type": "object" @@ -16330,6 +16350,65 @@ } ] }, + "WebSearchConfig": { + "additionalProperties": false, + "properties": { + "filters": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchFilters" + }, + { + "type": "null" + } + ] + }, + "search_context_size": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchContextSize" + }, + { + "type": "null" + } + ] + }, + "user_location": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchUserLocation" + }, + { + "type": "null" + } + ] + } + }, + "type": "object" + }, + "WebSearchContextSize": { + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + }, + "WebSearchFilters": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] + } + }, + "type": "object" + }, "WebSearchMode": { "enum": [ "disabled", @@ -16338,6 +16417,50 @@ ], "type": "string" }, + "WebSearchUserLocation": { + "additionalProperties": false, + "properties": { + "city": { + "type": [ + "string", + "null" + ] + }, + "country": { + "type": [ + "string", + "null" + ] + }, + "region": { + "type": [ + "string", + "null" + ] + }, + "timezone": { + "type": [ + "string", + "null" + ] + }, + "type": { + "allOf": [ + { + "$ref": "#/definitions/v2/WebSearchUserLocationType" + } + ], + "default": "approximate" + } + }, + "type": "object" + }, + "WebSearchUserLocationType": { + "enum": [ + "approximate" + ], + "type": "string" + }, "WindowsSandboxSetupCompletedNotification": { "$schema": "http://json-schema.org/draft-07/schema#", "properties": { diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json index da67d650c40..cc338997bec 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json @@ -2635,6 +2635,16 @@ "type": "null" } ] + }, + "web_search_config": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchConfig" + }, + { + "type": "null" + } + ] } }, "type": "object" @@ -8655,6 +8665,16 @@ "type": "null" } ] + }, + "web_search_config": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchConfig" + }, + { + "type": "null" + } + ] } }, "type": "object" @@ -14586,6 +14606,65 @@ } ] }, + "WebSearchConfig": { + "additionalProperties": false, + "properties": { + "filters": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchFilters" + }, + { + "type": "null" + } + ] + }, + "search_context_size": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchContextSize" + }, + { + "type": "null" + } + ] + }, + "user_location": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchUserLocation" + }, + { + "type": "null" + } + ] + } + }, + "type": "object" + }, + "WebSearchContextSize": { + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + }, + "WebSearchFilters": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] + } + }, + "type": "object" + }, "WebSearchMode": { "enum": [ "disabled", @@ -14594,6 +14673,50 @@ ], "type": "string" }, + "WebSearchUserLocation": { + "additionalProperties": false, + "properties": { + "city": { + "type": [ + "string", + "null" + ] + }, + "country": { + "type": [ + "string", + "null" + ] + }, + "region": { + "type": [ + "string", + "null" + ] + }, + "timezone": { + "type": [ + "string", + "null" + ] + }, + "type": { + "allOf": [ + { + "$ref": "#/definitions/WebSearchUserLocationType" + } + ], + "default": "approximate" + } + }, + "type": "object" + }, + "WebSearchUserLocationType": { + "enum": [ + "approximate" + ], + "type": "string" + }, "WindowsSandboxSetupCompletedNotification": { "$schema": "http://json-schema.org/draft-07/schema#", "properties": { diff --git a/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json b/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json index dd0a86fe910..a6eb15cb565 100644 --- a/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json +++ b/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json @@ -352,6 +352,16 @@ "type": "null" } ] + }, + "web_search_config": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchConfig" + }, + { + "type": "null" + } + ] } }, "type": "object" @@ -637,6 +647,16 @@ "type": "null" } ] + }, + "web_search_config": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchConfig" + }, + { + "type": "null" + } + ] } }, "type": "object" @@ -738,6 +758,65 @@ ], "type": "string" }, + "WebSearchConfig": { + "additionalProperties": false, + "properties": { + "filters": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchFilters" + }, + { + "type": "null" + } + ] + }, + "search_context_size": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchContextSize" + }, + { + "type": "null" + } + ] + }, + "user_location": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchUserLocation" + }, + { + "type": "null" + } + ] + } + }, + "type": "object" + }, + "WebSearchContextSize": { + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + }, + "WebSearchFilters": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] + } + }, + "type": "object" + }, "WebSearchMode": { "enum": [ "disabled", @@ -745,6 +824,50 @@ "live" ], "type": "string" + }, + "WebSearchUserLocation": { + "additionalProperties": false, + "properties": { + "city": { + "type": [ + "string", + "null" + ] + }, + "country": { + "type": [ + "string", + "null" + ] + }, + "region": { + "type": [ + "string", + "null" + ] + }, + "timezone": { + "type": [ + "string", + "null" + ] + }, + "type": { + "allOf": [ + { + "$ref": "#/definitions/WebSearchUserLocationType" + } + ], + "default": "approximate" + } + }, + "type": "object" + }, + "WebSearchUserLocationType": { + "enum": [ + "approximate" + ], + "type": "string" } }, "properties": { diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts new file mode 100644 index 00000000000..b21aad7edbd --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts @@ -0,0 +1,8 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { WebSearchContextSize } from "./WebSearchContextSize"; +import type { WebSearchFilters } from "./WebSearchFilters"; +import type { WebSearchUserLocation } from "./WebSearchUserLocation"; + +export type WebSearchConfig = { filters: WebSearchFilters | null, user_location: WebSearchUserLocation | null, search_context_size: WebSearchContextSize | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchContextSize.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchContextSize.ts new file mode 100644 index 00000000000..d6feedde849 --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchContextSize.ts @@ -0,0 +1,5 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type WebSearchContextSize = "low" | "medium" | "high"; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts new file mode 100644 index 00000000000..16ce24affce --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts @@ -0,0 +1,5 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type WebSearchFilters = { allowed_domains: Array | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts new file mode 100644 index 00000000000..dd103a433ac --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts @@ -0,0 +1,6 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { WebSearchUserLocationType } from "./WebSearchUserLocationType"; + +export type WebSearchUserLocation = { type: WebSearchUserLocationType, country: string | null, region: string | null, city: string | null, timezone: string | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts new file mode 100644 index 00000000000..103b47c0344 --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts @@ -0,0 +1,5 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type WebSearchUserLocationType = "approximate"; diff --git a/codex-rs/app-server-protocol/schema/typescript/index.ts b/codex-rs/app-server-protocol/schema/typescript/index.ts index 67b98c39467..48ced00832f 100644 --- a/codex-rs/app-server-protocol/schema/typescript/index.ts +++ b/codex-rs/app-server-protocol/schema/typescript/index.ts @@ -211,7 +211,12 @@ export type { ViewImageToolCallEvent } from "./ViewImageToolCallEvent"; export type { WarningEvent } from "./WarningEvent"; export type { WebSearchAction } from "./WebSearchAction"; export type { WebSearchBeginEvent } from "./WebSearchBeginEvent"; +export type { WebSearchConfig } from "./WebSearchConfig"; +export type { WebSearchContextSize } from "./WebSearchContextSize"; export type { WebSearchEndEvent } from "./WebSearchEndEvent"; +export type { WebSearchFilters } from "./WebSearchFilters"; export type { WebSearchItem } from "./WebSearchItem"; export type { WebSearchMode } from "./WebSearchMode"; +export type { WebSearchUserLocation } from "./WebSearchUserLocation"; +export type { WebSearchUserLocationType } from "./WebSearchUserLocationType"; export * as v2 from "./v2"; diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts b/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts index fb5d6ecbb93..bad841fd54b 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts @@ -6,6 +6,7 @@ import type { ReasoningEffort } from "../ReasoningEffort"; import type { ReasoningSummary } from "../ReasoningSummary"; import type { ServiceTier } from "../ServiceTier"; import type { Verbosity } from "../Verbosity"; +import type { WebSearchConfig } from "../WebSearchConfig"; import type { WebSearchMode } from "../WebSearchMode"; import type { JsonValue } from "../serde_json/JsonValue"; import type { AnalyticsConfig } from "./AnalyticsConfig"; @@ -15,4 +16,4 @@ import type { SandboxMode } from "./SandboxMode"; import type { SandboxWorkspaceWrite } from "./SandboxWorkspaceWrite"; import type { ToolsV2 } from "./ToolsV2"; -export type Config = {model: string | null, review_model: string | null, model_context_window: bigint | null, model_auto_compact_token_limit: bigint | null, model_provider: string | null, approval_policy: AskForApproval | null, sandbox_mode: SandboxMode | null, sandbox_workspace_write: SandboxWorkspaceWrite | null, forced_chatgpt_workspace_id: string | null, forced_login_method: ForcedLoginMethod | null, web_search: WebSearchMode | null, tools: ToolsV2 | null, profile: string | null, profiles: { [key in string]?: ProfileV2 }, instructions: string | null, developer_instructions: string | null, compact_prompt: string | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, service_tier: ServiceTier | null, analytics: AnalyticsConfig | null} & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); +export type Config = {model: string | null, review_model: string | null, model_context_window: bigint | null, model_auto_compact_token_limit: bigint | null, model_provider: string | null, approval_policy: AskForApproval | null, sandbox_mode: SandboxMode | null, sandbox_workspace_write: SandboxWorkspaceWrite | null, forced_chatgpt_workspace_id: string | null, forced_login_method: ForcedLoginMethod | null, web_search: WebSearchMode | null, web_search_config: WebSearchConfig | null, tools: ToolsV2 | null, profile: string | null, profiles: { [key in string]?: ProfileV2 }, instructions: string | null, developer_instructions: string | null, compact_prompt: string | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, service_tier: ServiceTier | null, analytics: AnalyticsConfig | null} & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts b/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts index 81d20993cbf..34739c8c7e5 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts @@ -5,8 +5,9 @@ import type { ReasoningEffort } from "../ReasoningEffort"; import type { ReasoningSummary } from "../ReasoningSummary"; import type { ServiceTier } from "../ServiceTier"; import type { Verbosity } from "../Verbosity"; +import type { WebSearchConfig } from "../WebSearchConfig"; import type { WebSearchMode } from "../WebSearchMode"; import type { JsonValue } from "../serde_json/JsonValue"; import type { AskForApproval } from "./AskForApproval"; -export type ProfileV2 = { model: string | null, model_provider: string | null, approval_policy: AskForApproval | null, service_tier: ServiceTier | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, web_search: WebSearchMode | null, chatgpt_base_url: string | null, } & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); +export type ProfileV2 = { model: string | null, model_provider: string | null, approval_policy: AskForApproval | null, service_tier: ServiceTier | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, web_search: WebSearchMode | null, web_search_config: WebSearchConfig | null, chatgpt_base_url: string | null, } & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index c65c41d1a72..99fec9031b3 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -21,6 +21,7 @@ use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode as CoreSandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::Verbosity; +use codex_protocol::config_types::WebSearchConfig; use codex_protocol::config_types::WebSearchMode; use codex_protocol::items::AgentMessageContent as CoreAgentMessageContent; use codex_protocol::items::TurnItem as CoreTurnItem; @@ -401,6 +402,7 @@ pub struct ProfileV2 { pub model_reasoning_summary: Option, pub model_verbosity: Option, pub web_search: Option, + pub web_search_config: Option, pub chatgpt_base_url: Option, #[serde(default, flatten)] pub additional: HashMap, @@ -498,6 +500,7 @@ pub struct Config { pub forced_chatgpt_workspace_id: Option, pub forced_login_method: Option, pub web_search: Option, + pub web_search_config: Option, pub tools: Option, pub profile: Option, #[serde(default)] diff --git a/codex-rs/app-server/tests/suite/v2/config_rpc.rs b/codex-rs/app-server/tests/suite/v2/config_rpc.rs index cd74710876b..0033972fbc8 100644 --- a/codex-rs/app-server/tests/suite/v2/config_rpc.rs +++ b/codex-rs/app-server/tests/suite/v2/config_rpc.rs @@ -23,6 +23,11 @@ use codex_app_server_protocol::ToolsV2; use codex_app_server_protocol::WriteStatus; use codex_core::config::set_project_trust_level; use codex_protocol::config_types::TrustLevel; +use codex_protocol::config_types::WebSearchConfig; +use codex_protocol::config_types::WebSearchContextSize; +use codex_protocol::config_types::WebSearchFilters; +use codex_protocol::config_types::WebSearchUserLocation; +use codex_protocol::config_types::WebSearchUserLocationType; use codex_protocol::openai_models::ReasoningEffort; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; @@ -148,6 +153,63 @@ view_image = false Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn config_read_includes_web_search_config() -> Result<()> { + let codex_home = TempDir::new()?; + write_config( + &codex_home, + r#" +web_search = "live" + +[web_search_config] +search_context_size = "high" + +[web_search_config.filters] +allowed_domains = ["example.com"] + +[web_search_config.user_location] +country = "US" +city = "New York" +timezone = "America/New_York" +"#, + )?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let request_id = mcp + .send_config_read_request(ConfigReadParams { + include_layers: false, + cwd: None, + }) + .await?; + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await??; + let ConfigReadResponse { config, .. } = to_response(resp)?; + + assert_eq!( + config.web_search_config, + Some(WebSearchConfig { + filters: Some(WebSearchFilters { + allowed_domains: Some(vec!["example.com".to_string()]), + }), + user_location: Some(WebSearchUserLocation { + r#type: WebSearchUserLocationType::Approximate, + country: Some("US".to_string()), + region: None, + city: Some("New York".to_string()), + timezone: Some("America/New_York".to_string()), + }), + search_context_size: Some(WebSearchContextSize::High), + }) + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn config_read_includes_apps() -> Result<()> { let codex_home = TempDir::new()?; diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 1f94469530e..0eb3f31c78e 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -539,6 +539,9 @@ "web_search": { "$ref": "#/definitions/WebSearchMode" }, + "web_search_config": { + "$ref": "#/definitions/WebSearchConfig" + }, "windows": { "allOf": [ { @@ -1548,6 +1551,41 @@ ], "type": "string" }, + "WebSearchConfig": { + "additionalProperties": false, + "properties": { + "filters": { + "$ref": "#/definitions/WebSearchFilters" + }, + "search_context_size": { + "$ref": "#/definitions/WebSearchContextSize" + }, + "user_location": { + "$ref": "#/definitions/WebSearchUserLocation" + } + }, + "type": "object" + }, + "WebSearchContextSize": { + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + }, + "WebSearchFilters": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, "WebSearchMode": { "enum": [ "disabled", @@ -1556,6 +1594,38 @@ ], "type": "string" }, + "WebSearchUserLocation": { + "additionalProperties": false, + "properties": { + "city": { + "type": "string" + }, + "country": { + "type": "string" + }, + "region": { + "type": "string" + }, + "timezone": { + "type": "string" + }, + "type": { + "allOf": [ + { + "$ref": "#/definitions/WebSearchUserLocationType" + } + ], + "default": "approximate" + } + }, + "type": "object" + }, + "WebSearchUserLocationType": { + "enum": [ + "approximate" + ], + "type": "string" + }, "WindowsSandboxModeToml": { "enum": [ "elevated", @@ -2224,6 +2294,14 @@ ], "description": "Controls the web search tool mode: disabled, cached, or live." }, + "web_search_config": { + "allOf": [ + { + "$ref": "#/definitions/WebSearchConfig" + } + ], + "description": "Optional structured configuration for the web search tool." + }, "windows": { "allOf": [ { diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index b2bd4d0b3aa..1fb3b7b47c8 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -154,8 +154,13 @@ fn strip_total_output_header(output: &str) -> Option<(&str, u32)> { pub(crate) mod tools { use crate::tools::spec::JsonSchema; + use codex_protocol::config_types::WebSearchContextSize; + use codex_protocol::config_types::WebSearchFilters; + use codex_protocol::config_types::WebSearchUserLocation; + use codex_protocol::config_types::WebSearchUserLocationType; use serde::Deserialize; use serde::Serialize; + use serde::Serializer; /// When serialized as JSON, this produces a valid "Tool" in the OpenAI /// Responses API. @@ -176,6 +181,18 @@ pub(crate) mod tools { WebSearch { #[serde(skip_serializing_if = "Option::is_none")] external_web_access: Option, + #[serde( + skip_serializing_if = "Option::is_none", + serialize_with = "serialize_web_search_filters" + )] + filters: Option, + #[serde( + skip_serializing_if = "Option::is_none", + serialize_with = "serialize_web_search_user_location" + )] + user_location: Option, + #[serde(skip_serializing_if = "Option::is_none")] + search_context_size: Option, #[serde(skip_serializing_if = "Option::is_none")] search_content_types: Option>, }, @@ -195,6 +212,66 @@ pub(crate) mod tools { } } + fn serialize_web_search_filters( + filters: &Option, + serializer: S, + ) -> Result + where + S: Serializer, + { + match filters { + Some(filters) => { + #[derive(Serialize)] + struct SerializableWebSearchFilters<'a> { + #[serde(skip_serializing_if = "Option::is_none")] + allowed_domains: Option<&'a Vec>, + } + + SerializableWebSearchFilters { + allowed_domains: filters.allowed_domains.as_ref(), + } + .serialize(serializer) + } + None => serializer.serialize_none(), + } + } + + fn serialize_web_search_user_location( + user_location: &Option, + serializer: S, + ) -> Result + where + S: Serializer, + { + match user_location { + Some(user_location) => { + #[derive(Serialize)] + struct SerializableWebSearchUserLocation<'a> { + #[serde(rename = "type")] + r#type: WebSearchUserLocationType, + #[serde(skip_serializing_if = "Option::is_none")] + country: Option<&'a String>, + #[serde(skip_serializing_if = "Option::is_none")] + region: Option<&'a String>, + #[serde(skip_serializing_if = "Option::is_none")] + city: Option<&'a String>, + #[serde(skip_serializing_if = "Option::is_none")] + timezone: Option<&'a String>, + } + + SerializableWebSearchUserLocation { + r#type: user_location.r#type, + country: user_location.country.as_ref(), + region: user_location.region.as_ref(), + city: user_location.city.as_ref(), + timezone: user_location.timezone.as_ref(), + } + .serialize(serializer) + } + None => serializer.serialize_none(), + } + } + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct FreeformTool { pub(crate) name: String, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index bd32dfe0296..44d365233d6 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -744,6 +744,7 @@ impl TurnContext { web_search_mode: self.tools_config.web_search_mode, session_source: self.session_source.clone(), }) + .with_web_search_config(self.tools_config.web_search_config.clone()) .with_allow_login_shell(self.tools_config.allow_login_shell) .with_agent_roles(config.agent_roles.clone()); @@ -1119,6 +1120,7 @@ impl Session { web_search_mode: Some(per_turn_config.web_search_mode.value()), session_source: session_source.clone(), }) + .with_web_search_config(per_turn_config.web_search_config.clone()) .with_allow_login_shell(per_turn_config.permissions.allow_login_shell) .with_agent_roles(per_turn_config.agent_roles.clone()); @@ -4911,6 +4913,7 @@ async fn spawn_review_thread( web_search_mode: Some(review_web_search_mode), session_source: parent_turn_context.session_source.clone(), }) + .with_web_search_config(None) .with_allow_login_shell(config.permissions.allow_login_shell) .with_agent_roles(config.agent_roles.clone()); diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index e489720bab9..6ea0cc1dc64 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -67,7 +67,10 @@ use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::TrustLevel; use codex_protocol::config_types::Verbosity; +use codex_protocol::config_types::WebSearchConfig; +use codex_protocol::config_types::WebSearchFilters; use codex_protocol::config_types::WebSearchMode; +use codex_protocol::config_types::WebSearchUserLocation; use codex_protocol::config_types::WindowsSandboxLevel; use codex_protocol::models::MacOsSeatbeltProfileExtensions; use codex_protocol::openai_models::ModelsResponse; @@ -465,6 +468,9 @@ pub struct Config { /// Explicit or feature-derived web search mode. pub web_search_mode: Constrained, + /// Additional parameters for the web search tool when it is enabled. + pub web_search_config: Option, + /// If set to `true`, used only the experimental unified exec tool. pub use_experimental_unified_exec_tool: bool, @@ -1216,6 +1222,9 @@ pub struct ConfigToml { /// Controls the web search tool mode: disabled, cached, or live. pub web_search: Option, + /// Optional structured configuration for the web search tool. + pub web_search_config: Option, + /// Nested tools section for feature toggles pub tools: Option, @@ -1637,6 +1646,60 @@ fn resolve_web_search_mode( None } +fn resolve_web_search_config( + config_toml: &ConfigToml, + config_profile: &ConfigProfile, +) -> Option { + let base = config_toml.web_search_config.as_ref(); + let profile = config_profile.web_search_config.as_ref(); + + match (base, profile) { + (None, None) => None, + (Some(base), None) => Some(base.clone()), + (None, Some(profile)) => Some(profile.clone()), + (Some(base), Some(profile)) => Some(WebSearchConfig { + filters: match (base.filters.as_ref(), profile.filters.as_ref()) { + (None, None) => None, + (Some(base_filters), None) => Some(base_filters.clone()), + (None, Some(profile_filters)) => Some(profile_filters.clone()), + (Some(base_filters), Some(profile_filters)) => Some(WebSearchFilters { + allowed_domains: profile_filters + .allowed_domains + .clone() + .or_else(|| base_filters.allowed_domains.clone()), + }), + }, + user_location: match (base.user_location.as_ref(), profile.user_location.as_ref()) { + (None, None) => None, + (Some(base_user_location), None) => Some(base_user_location.clone()), + (None, Some(profile_user_location)) => Some(profile_user_location.clone()), + (Some(base_user_location), Some(profile_user_location)) => { + Some(WebSearchUserLocation { + r#type: profile_user_location.r#type, + country: profile_user_location + .country + .clone() + .or_else(|| base_user_location.country.clone()), + region: profile_user_location + .region + .clone() + .or_else(|| base_user_location.region.clone()), + city: profile_user_location + .city + .clone() + .or_else(|| base_user_location.city.clone()), + timezone: profile_user_location + .timezone + .clone() + .or_else(|| base_user_location.timezone.clone()), + }) + } + }, + search_context_size: profile.search_context_size.or(base.search_context_size), + }), + } +} + pub(crate) fn resolve_web_search_mode_for_turn( web_search_mode: &Constrained, sandbox_policy: &SandboxPolicy, @@ -1836,6 +1899,11 @@ impl Config { } let web_search_mode = resolve_web_search_mode(&cfg, &config_profile, &features) .unwrap_or(WebSearchMode::Cached); + let web_search_config = resolve_web_search_config(&cfg, &config_profile); + // TODO(dylan): We should be able to leverage ConfigLayerStack so that + // we can reliably check this at every config level. + let did_user_set_custom_approval_policy_or_sandbox_mode = + approval_policy_was_explicit || sandbox_mode_was_explicit; let mut model_providers = built_in_model_providers(); // Merge user-defined providers into the built-in list. @@ -2224,6 +2292,7 @@ impl Config { forced_login_method, include_apply_patch_tool: include_apply_patch_tool_flag, web_search_mode: constrained_web_search_mode.value, + web_search_config, use_experimental_unified_exec_tool, background_terminal_max_timeout, ghost_snapshot, diff --git a/codex-rs/core/src/config/profile.rs b/codex-rs/core/src/config/profile.rs index 6d4cd230901..f2867b25391 100644 --- a/codex-rs/core/src/config/profile.rs +++ b/codex-rs/core/src/config/profile.rs @@ -10,6 +10,7 @@ use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::Verbosity; +use codex_protocol::config_types::WebSearchConfig; use codex_protocol::config_types::WebSearchMode; use codex_protocol::openai_models::ReasoningEffort; @@ -50,6 +51,7 @@ pub struct ConfigProfile { pub tools_web_search: Option, pub tools_view_image: Option, pub web_search: Option, + pub web_search_config: Option, pub analytics: Option, #[serde(default)] pub windows: Option, diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index f0c952e0787..9c6cb95a648 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -18,6 +18,7 @@ use crate::tools::handlers::multi_agents::MAX_WAIT_TIMEOUT_MS; use crate::tools::handlers::multi_agents::MIN_WAIT_TIMEOUT_MS; use crate::tools::handlers::request_user_input_tool_description; use crate::tools::registry::ToolRegistryBuilder; +use codex_protocol::config_types::WebSearchConfig; use codex_protocol::config_types::WebSearchMode; use codex_protocol::dynamic_tools::DynamicToolSpec; use codex_protocol::models::VIEW_IMAGE_TOOL_NAME; @@ -58,6 +59,7 @@ pub(crate) struct ToolsConfig { pub allow_login_shell: bool, pub apply_patch_tool_type: Option, pub web_search_mode: Option, + pub web_search_config: Option, pub web_search_tool_type: WebSearchToolType, pub image_gen_tool: bool, pub agent_roles: BTreeMap, @@ -158,6 +160,7 @@ impl ToolsConfig { allow_login_shell: true, apply_patch_tool_type, web_search_mode: *web_search_mode, + web_search_config: None, web_search_tool_type: model_info.web_search_tool_type, image_gen_tool: include_image_gen_tool, agent_roles: BTreeMap::new(), @@ -184,6 +187,11 @@ impl ToolsConfig { self.allow_login_shell = allow_login_shell; self } + + pub fn with_web_search_config(mut self, web_search_config: Option) -> Self { + self.web_search_config = web_search_config; + self + } } fn supports_image_generation(model_info: &ModelInfo) -> bool { @@ -1979,6 +1987,18 @@ pub(crate) fn build_specs( builder.push_spec(ToolSpec::WebSearch { external_web_access: Some(external_web_access), + filters: config + .web_search_config + .as_ref() + .and_then(|cfg| cfg.filters.clone()), + user_location: config + .web_search_config + .as_ref() + .and_then(|cfg| cfg.user_location.clone()), + search_context_size: config + .web_search_config + .as_ref() + .and_then(|cfg| cfg.search_context_size), search_content_types, }); } @@ -2266,6 +2286,9 @@ mod tests { create_apply_patch_freeform_tool(), ToolSpec::WebSearch { external_web_access: Some(true), + filters: None, + user_location: None, + search_context_size: None, search_content_types: None, }, create_view_image_tool(), @@ -2592,6 +2615,9 @@ mod tests { tool.spec, ToolSpec::WebSearch { external_web_access: Some(false), + filters: None, + user_location: None, + search_context_size: None, search_content_types: None, } ); @@ -2617,6 +2643,51 @@ mod tests { tool.spec, ToolSpec::WebSearch { external_web_access: Some(true), + filters: None, + user_location: None, + search_context_size: None, + search_content_types: None, + } + ); + } + + #[test] + fn web_search_config_is_forwarded_to_tool_spec() { + let config = test_config(); + let model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + let web_search_config = WebSearchConfig { + filters: Some(codex_protocol::config_types::WebSearchFilters { + allowed_domains: Some(vec!["example.com".to_string()]), + }), + user_location: Some(codex_protocol::config_types::WebSearchUserLocation { + r#type: codex_protocol::config_types::WebSearchUserLocationType::Approximate, + country: Some("US".to_string()), + region: Some("California".to_string()), + city: Some("San Francisco".to_string()), + timezone: Some("America/Los_Angeles".to_string()), + }), + search_context_size: Some(codex_protocol::config_types::WebSearchContextSize::High), + }; + + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }) + .with_web_search_config(Some(web_search_config.clone())); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + let tool = find_tool(&tools, "web_search"); + assert_eq!( + tool.spec, + ToolSpec::WebSearch { + external_web_access: Some(true), + filters: web_search_config.filters, + user_location: web_search_config.user_location, + search_context_size: web_search_config.search_context_size, search_content_types: None, } ); @@ -2643,6 +2714,9 @@ mod tests { tool.spec, ToolSpec::WebSearch { external_web_access: Some(true), + filters: None, + user_location: None, + search_context_size: None, search_content_types: Some( WEB_SEARCH_CONTENT_TYPES .into_iter() diff --git a/codex-rs/protocol/src/config_types.rs b/codex-rs/protocol/src/config_types.rs index b467d34857d..b4a964bbde7 100644 --- a/codex-rs/protocol/src/config_types.rs +++ b/codex-rs/protocol/src/config_types.rs @@ -113,6 +113,50 @@ pub enum WebSearchMode { Live, } +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Display, JsonSchema, TS)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum WebSearchContextSize { + Low, + Medium, + High, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] +#[schemars(deny_unknown_fields)] +pub struct WebSearchFilters { + pub allowed_domains: Option>, +} + +#[derive( + Debug, Serialize, Deserialize, Clone, Copy, Default, PartialEq, Eq, Display, JsonSchema, TS, +)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum WebSearchUserLocationType { + #[default] + Approximate, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] +#[schemars(deny_unknown_fields)] +pub struct WebSearchUserLocation { + #[serde(default)] + pub r#type: WebSearchUserLocationType, + pub country: Option, + pub region: Option, + pub city: Option, + pub timezone: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] +#[schemars(deny_unknown_fields)] +pub struct WebSearchConfig { + pub filters: Option, + pub user_location: Option, + pub search_context_size: Option, +} + #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Display, JsonSchema, TS)] #[serde(rename_all = "lowercase")] #[strum(serialize_all = "lowercase")] From d490664e4d78d61d957bc81bb82ffc66f8b2b6d5 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 6 Mar 2026 13:51:29 -0500 Subject: [PATCH 2/7] Move web search config under tools.web_search --- .../codex_app_server_protocol.schemas.json | 138 +- .../codex_app_server_protocol.v2.schemas.json | 138 +- .../schema/json/v2/ConfigReadResponse.json | 138 +- .../schema/typescript/WebSearchConfig.ts | 8 - ...bSearchFilters.ts => WebSearchLocation.ts} | 2 +- .../schema/typescript/WebSearchToolConfig.ts | 7 + .../typescript/WebSearchUserLocation.ts | 6 - .../typescript/WebSearchUserLocationType.ts | 5 - .../schema/typescript/index.ts | 6 +- .../schema/typescript/v2/Config.ts | 3 +- .../schema/typescript/v2/ProfileV2.ts | 4 +- .../schema/typescript/v2/ToolsV2.ts | 3 +- .../app-server-protocol/src/protocol/v1.rs | 4 +- .../app-server-protocol/src/protocol/v2.rs | 8 +- .../app-server/tests/suite/v2/config_rpc.rs | 60 +- codex-rs/core/config.schema.json | 92 +- codex-rs/core/src/codex.rs | 3874 +++++++++++++++++ codex-rs/core/src/config/managed_features.rs | 15 - codex-rs/core/src/config/mod.rs | 86 +- codex-rs/core/src/config/profile.rs | 6 +- codex-rs/core/src/features.rs | 12 +- codex-rs/core/src/features/legacy.rs | 7 - codex-rs/core/tests/suite/web_search.rs | 60 + codex-rs/protocol/src/config_types.rs | 43 + 24 files changed, 4266 insertions(+), 459 deletions(-) delete mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts rename codex-rs/app-server-protocol/schema/typescript/{WebSearchFilters.ts => WebSearchLocation.ts} (54%) create mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchToolConfig.ts delete mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts delete mode 100644 codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json index debeacd2de1..642057cbf7f 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json @@ -9595,16 +9595,6 @@ "type": "null" } ] - }, - "web_search_config": { - "anyOf": [ - { - "$ref": "#/definitions/v2/WebSearchConfig" - }, - { - "type": "null" - } - ] } }, "type": "object" @@ -11999,20 +11989,20 @@ } ] }, - "web_search": { + "tools": { "anyOf": [ { - "$ref": "#/definitions/v2/WebSearchMode" + "$ref": "#/definitions/v2/ToolsV2" }, { "type": "null" } ] }, - "web_search_config": { + "web_search": { "anyOf": [ { - "$ref": "#/definitions/v2/WebSearchConfig" + "$ref": "#/definitions/v2/WebSearchMode" }, { "type": "null" @@ -15749,9 +15739,13 @@ ] }, "web_search": { - "type": [ - "boolean", - "null" + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchToolConfig" + }, + { + "type": "null" + } ] } }, @@ -16350,42 +16344,6 @@ } ] }, - "WebSearchConfig": { - "additionalProperties": false, - "properties": { - "filters": { - "anyOf": [ - { - "$ref": "#/definitions/v2/WebSearchFilters" - }, - { - "type": "null" - } - ] - }, - "search_context_size": { - "anyOf": [ - { - "$ref": "#/definitions/v2/WebSearchContextSize" - }, - { - "type": "null" - } - ] - }, - "user_location": { - "anyOf": [ - { - "$ref": "#/definitions/v2/WebSearchUserLocation" - }, - { - "type": "null" - } - ] - } - }, - "type": "object" - }, "WebSearchContextSize": { "enum": [ "low", @@ -16394,30 +16352,7 @@ ], "type": "string" }, - "WebSearchFilters": { - "additionalProperties": false, - "properties": { - "allowed_domains": { - "items": { - "type": "string" - }, - "type": [ - "array", - "null" - ] - } - }, - "type": "object" - }, - "WebSearchMode": { - "enum": [ - "disabled", - "cached", - "live" - ], - "type": "string" - }, - "WebSearchUserLocation": { + "WebSearchLocation": { "additionalProperties": false, "properties": { "city": { @@ -16443,24 +16378,53 @@ "string", "null" ] - }, - "type": { - "allOf": [ - { - "$ref": "#/definitions/v2/WebSearchUserLocationType" - } - ], - "default": "approximate" } }, "type": "object" }, - "WebSearchUserLocationType": { + "WebSearchMode": { "enum": [ - "approximate" + "disabled", + "cached", + "live" ], "type": "string" }, + "WebSearchToolConfig": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] + }, + "context_size": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchContextSize" + }, + { + "type": "null" + } + ] + }, + "location": { + "anyOf": [ + { + "$ref": "#/definitions/v2/WebSearchLocation" + }, + { + "type": "null" + } + ] + } + }, + "type": "object" + }, "WindowsSandboxSetupCompletedNotification": { "$schema": "http://json-schema.org/draft-07/schema#", "properties": { diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json index cc338997bec..8b8b1be8150 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json @@ -2635,16 +2635,6 @@ "type": "null" } ] - }, - "web_search_config": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchConfig" - }, - { - "type": "null" - } - ] } }, "type": "object" @@ -8656,20 +8646,20 @@ } ] }, - "web_search": { + "tools": { "anyOf": [ { - "$ref": "#/definitions/WebSearchMode" + "$ref": "#/definitions/ToolsV2" }, { "type": "null" } ] }, - "web_search_config": { + "web_search": { "anyOf": [ { - "$ref": "#/definitions/WebSearchConfig" + "$ref": "#/definitions/WebSearchMode" }, { "type": "null" @@ -13781,9 +13771,13 @@ ] }, "web_search": { - "type": [ - "boolean", - "null" + "anyOf": [ + { + "$ref": "#/definitions/WebSearchToolConfig" + }, + { + "type": "null" + } ] } }, @@ -14606,42 +14600,6 @@ } ] }, - "WebSearchConfig": { - "additionalProperties": false, - "properties": { - "filters": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchFilters" - }, - { - "type": "null" - } - ] - }, - "search_context_size": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchContextSize" - }, - { - "type": "null" - } - ] - }, - "user_location": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchUserLocation" - }, - { - "type": "null" - } - ] - } - }, - "type": "object" - }, "WebSearchContextSize": { "enum": [ "low", @@ -14650,30 +14608,7 @@ ], "type": "string" }, - "WebSearchFilters": { - "additionalProperties": false, - "properties": { - "allowed_domains": { - "items": { - "type": "string" - }, - "type": [ - "array", - "null" - ] - } - }, - "type": "object" - }, - "WebSearchMode": { - "enum": [ - "disabled", - "cached", - "live" - ], - "type": "string" - }, - "WebSearchUserLocation": { + "WebSearchLocation": { "additionalProperties": false, "properties": { "city": { @@ -14699,24 +14634,53 @@ "string", "null" ] - }, - "type": { - "allOf": [ - { - "$ref": "#/definitions/WebSearchUserLocationType" - } - ], - "default": "approximate" } }, "type": "object" }, - "WebSearchUserLocationType": { + "WebSearchMode": { "enum": [ - "approximate" + "disabled", + "cached", + "live" ], "type": "string" }, + "WebSearchToolConfig": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] + }, + "context_size": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchContextSize" + }, + { + "type": "null" + } + ] + }, + "location": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchLocation" + }, + { + "type": "null" + } + ] + } + }, + "type": "object" + }, "WindowsSandboxSetupCompletedNotification": { "$schema": "http://json-schema.org/draft-07/schema#", "properties": { diff --git a/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json b/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json index a6eb15cb565..90828da0bae 100644 --- a/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json +++ b/codex-rs/app-server-protocol/schema/json/v2/ConfigReadResponse.json @@ -352,16 +352,6 @@ "type": "null" } ] - }, - "web_search_config": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchConfig" - }, - { - "type": "null" - } - ] } }, "type": "object" @@ -638,20 +628,20 @@ } ] }, - "web_search": { + "tools": { "anyOf": [ { - "$ref": "#/definitions/WebSearchMode" + "$ref": "#/definitions/ToolsV2" }, { "type": "null" } ] }, - "web_search_config": { + "web_search": { "anyOf": [ { - "$ref": "#/definitions/WebSearchConfig" + "$ref": "#/definitions/WebSearchMode" }, { "type": "null" @@ -741,50 +731,9 @@ ] }, "web_search": { - "type": [ - "boolean", - "null" - ] - } - }, - "type": "object" - }, - "Verbosity": { - "description": "Controls output length/detail on GPT-5 models via the Responses API. Serialized with lowercase values to match the OpenAI API.", - "enum": [ - "low", - "medium", - "high" - ], - "type": "string" - }, - "WebSearchConfig": { - "additionalProperties": false, - "properties": { - "filters": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchFilters" - }, - { - "type": "null" - } - ] - }, - "search_context_size": { - "anyOf": [ - { - "$ref": "#/definitions/WebSearchContextSize" - }, - { - "type": "null" - } - ] - }, - "user_location": { "anyOf": [ { - "$ref": "#/definitions/WebSearchUserLocation" + "$ref": "#/definitions/WebSearchToolConfig" }, { "type": "null" @@ -794,7 +743,8 @@ }, "type": "object" }, - "WebSearchContextSize": { + "Verbosity": { + "description": "Controls output length/detail on GPT-5 models via the Responses API. Serialized with lowercase values to match the OpenAI API.", "enum": [ "low", "medium", @@ -802,30 +752,15 @@ ], "type": "string" }, - "WebSearchFilters": { - "additionalProperties": false, - "properties": { - "allowed_domains": { - "items": { - "type": "string" - }, - "type": [ - "array", - "null" - ] - } - }, - "type": "object" - }, - "WebSearchMode": { + "WebSearchContextSize": { "enum": [ - "disabled", - "cached", - "live" + "low", + "medium", + "high" ], "type": "string" }, - "WebSearchUserLocation": { + "WebSearchLocation": { "additionalProperties": false, "properties": { "city": { @@ -851,23 +786,52 @@ "string", "null" ] - }, - "type": { - "allOf": [ - { - "$ref": "#/definitions/WebSearchUserLocationType" - } - ], - "default": "approximate" } }, "type": "object" }, - "WebSearchUserLocationType": { + "WebSearchMode": { "enum": [ - "approximate" + "disabled", + "cached", + "live" ], "type": "string" + }, + "WebSearchToolConfig": { + "additionalProperties": false, + "properties": { + "allowed_domains": { + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] + }, + "context_size": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchContextSize" + }, + { + "type": "null" + } + ] + }, + "location": { + "anyOf": [ + { + "$ref": "#/definitions/WebSearchLocation" + }, + { + "type": "null" + } + ] + } + }, + "type": "object" } }, "properties": { diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts deleted file mode 100644 index b21aad7edbd..00000000000 --- a/codex-rs/app-server-protocol/schema/typescript/WebSearchConfig.ts +++ /dev/null @@ -1,8 +0,0 @@ -// GENERATED CODE! DO NOT MODIFY BY HAND! - -// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. -import type { WebSearchContextSize } from "./WebSearchContextSize"; -import type { WebSearchFilters } from "./WebSearchFilters"; -import type { WebSearchUserLocation } from "./WebSearchUserLocation"; - -export type WebSearchConfig = { filters: WebSearchFilters | null, user_location: WebSearchUserLocation | null, search_context_size: WebSearchContextSize | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchLocation.ts similarity index 54% rename from codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts rename to codex-rs/app-server-protocol/schema/typescript/WebSearchLocation.ts index 16ce24affce..12319983d7d 100644 --- a/codex-rs/app-server-protocol/schema/typescript/WebSearchFilters.ts +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchLocation.ts @@ -2,4 +2,4 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. -export type WebSearchFilters = { allowed_domains: Array | null, }; +export type WebSearchLocation = { country: string | null, region: string | null, city: string | null, timezone: string | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchToolConfig.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchToolConfig.ts new file mode 100644 index 00000000000..c14067cef44 --- /dev/null +++ b/codex-rs/app-server-protocol/schema/typescript/WebSearchToolConfig.ts @@ -0,0 +1,7 @@ +// GENERATED CODE! DO NOT MODIFY BY HAND! + +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { WebSearchContextSize } from "./WebSearchContextSize"; +import type { WebSearchLocation } from "./WebSearchLocation"; + +export type WebSearchToolConfig = { context_size: WebSearchContextSize | null, allowed_domains: Array | null, location: WebSearchLocation | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts deleted file mode 100644 index dd103a433ac..00000000000 --- a/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocation.ts +++ /dev/null @@ -1,6 +0,0 @@ -// GENERATED CODE! DO NOT MODIFY BY HAND! - -// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. -import type { WebSearchUserLocationType } from "./WebSearchUserLocationType"; - -export type WebSearchUserLocation = { type: WebSearchUserLocationType, country: string | null, region: string | null, city: string | null, timezone: string | null, }; diff --git a/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts b/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts deleted file mode 100644 index 103b47c0344..00000000000 --- a/codex-rs/app-server-protocol/schema/typescript/WebSearchUserLocationType.ts +++ /dev/null @@ -1,5 +0,0 @@ -// GENERATED CODE! DO NOT MODIFY BY HAND! - -// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. - -export type WebSearchUserLocationType = "approximate"; diff --git a/codex-rs/app-server-protocol/schema/typescript/index.ts b/codex-rs/app-server-protocol/schema/typescript/index.ts index 48ced00832f..af0db3367e8 100644 --- a/codex-rs/app-server-protocol/schema/typescript/index.ts +++ b/codex-rs/app-server-protocol/schema/typescript/index.ts @@ -211,12 +211,10 @@ export type { ViewImageToolCallEvent } from "./ViewImageToolCallEvent"; export type { WarningEvent } from "./WarningEvent"; export type { WebSearchAction } from "./WebSearchAction"; export type { WebSearchBeginEvent } from "./WebSearchBeginEvent"; -export type { WebSearchConfig } from "./WebSearchConfig"; export type { WebSearchContextSize } from "./WebSearchContextSize"; export type { WebSearchEndEvent } from "./WebSearchEndEvent"; -export type { WebSearchFilters } from "./WebSearchFilters"; export type { WebSearchItem } from "./WebSearchItem"; +export type { WebSearchLocation } from "./WebSearchLocation"; export type { WebSearchMode } from "./WebSearchMode"; -export type { WebSearchUserLocation } from "./WebSearchUserLocation"; -export type { WebSearchUserLocationType } from "./WebSearchUserLocationType"; +export type { WebSearchToolConfig } from "./WebSearchToolConfig"; export * as v2 from "./v2"; diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts b/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts index bad841fd54b..fb5d6ecbb93 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/Config.ts @@ -6,7 +6,6 @@ import type { ReasoningEffort } from "../ReasoningEffort"; import type { ReasoningSummary } from "../ReasoningSummary"; import type { ServiceTier } from "../ServiceTier"; import type { Verbosity } from "../Verbosity"; -import type { WebSearchConfig } from "../WebSearchConfig"; import type { WebSearchMode } from "../WebSearchMode"; import type { JsonValue } from "../serde_json/JsonValue"; import type { AnalyticsConfig } from "./AnalyticsConfig"; @@ -16,4 +15,4 @@ import type { SandboxMode } from "./SandboxMode"; import type { SandboxWorkspaceWrite } from "./SandboxWorkspaceWrite"; import type { ToolsV2 } from "./ToolsV2"; -export type Config = {model: string | null, review_model: string | null, model_context_window: bigint | null, model_auto_compact_token_limit: bigint | null, model_provider: string | null, approval_policy: AskForApproval | null, sandbox_mode: SandboxMode | null, sandbox_workspace_write: SandboxWorkspaceWrite | null, forced_chatgpt_workspace_id: string | null, forced_login_method: ForcedLoginMethod | null, web_search: WebSearchMode | null, web_search_config: WebSearchConfig | null, tools: ToolsV2 | null, profile: string | null, profiles: { [key in string]?: ProfileV2 }, instructions: string | null, developer_instructions: string | null, compact_prompt: string | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, service_tier: ServiceTier | null, analytics: AnalyticsConfig | null} & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); +export type Config = {model: string | null, review_model: string | null, model_context_window: bigint | null, model_auto_compact_token_limit: bigint | null, model_provider: string | null, approval_policy: AskForApproval | null, sandbox_mode: SandboxMode | null, sandbox_workspace_write: SandboxWorkspaceWrite | null, forced_chatgpt_workspace_id: string | null, forced_login_method: ForcedLoginMethod | null, web_search: WebSearchMode | null, tools: ToolsV2 | null, profile: string | null, profiles: { [key in string]?: ProfileV2 }, instructions: string | null, developer_instructions: string | null, compact_prompt: string | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, service_tier: ServiceTier | null, analytics: AnalyticsConfig | null} & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts b/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts index 34739c8c7e5..f2c72b3ae65 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/ProfileV2.ts @@ -5,9 +5,9 @@ import type { ReasoningEffort } from "../ReasoningEffort"; import type { ReasoningSummary } from "../ReasoningSummary"; import type { ServiceTier } from "../ServiceTier"; import type { Verbosity } from "../Verbosity"; -import type { WebSearchConfig } from "../WebSearchConfig"; import type { WebSearchMode } from "../WebSearchMode"; import type { JsonValue } from "../serde_json/JsonValue"; import type { AskForApproval } from "./AskForApproval"; +import type { ToolsV2 } from "./ToolsV2"; -export type ProfileV2 = { model: string | null, model_provider: string | null, approval_policy: AskForApproval | null, service_tier: ServiceTier | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, web_search: WebSearchMode | null, web_search_config: WebSearchConfig | null, chatgpt_base_url: string | null, } & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); +export type ProfileV2 = { model: string | null, model_provider: string | null, approval_policy: AskForApproval | null, service_tier: ServiceTier | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, web_search: WebSearchMode | null, tools: ToolsV2 | null, chatgpt_base_url: string | null, } & ({ [key in string]?: number | string | boolean | Array | { [key in string]?: JsonValue } | null }); diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/ToolsV2.ts b/codex-rs/app-server-protocol/schema/typescript/v2/ToolsV2.ts index 0b1bee51460..784991f017d 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/ToolsV2.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/ToolsV2.ts @@ -1,5 +1,6 @@ // GENERATED CODE! DO NOT MODIFY BY HAND! // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { WebSearchToolConfig } from "../WebSearchToolConfig"; -export type ToolsV2 = { web_search: boolean | null, view_image: boolean | null, }; +export type ToolsV2 = { web_search: WebSearchToolConfig | null, view_image: boolean | null, }; diff --git a/codex-rs/app-server-protocol/src/protocol/v1.rs b/codex-rs/app-server-protocol/src/protocol/v1.rs index d393f97f72b..c00ec2d5b1c 100644 --- a/codex-rs/app-server-protocol/src/protocol/v1.rs +++ b/codex-rs/app-server-protocol/src/protocol/v1.rs @@ -7,6 +7,7 @@ use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::Verbosity; +use codex_protocol::config_types::WebSearchToolConfig; use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ReasoningEffort; use codex_protocol::parse_command::ParsedCommand; @@ -385,12 +386,13 @@ pub struct Profile { pub model_reasoning_summary: Option, pub model_verbosity: Option, pub chatgpt_base_url: Option, + pub tools: Option, } #[derive(Deserialize, Debug, Clone, PartialEq, Serialize, JsonSchema, TS)] #[serde(rename_all = "camelCase")] pub struct Tools { - pub web_search: Option, + pub web_search: Option, pub view_image: Option, } diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index 99fec9031b3..a34ecd5b21c 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -21,8 +21,8 @@ use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode as CoreSandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::Verbosity; -use codex_protocol::config_types::WebSearchConfig; use codex_protocol::config_types::WebSearchMode; +use codex_protocol::config_types::WebSearchToolConfig; use codex_protocol::items::AgentMessageContent as CoreAgentMessageContent; use codex_protocol::items::TurnItem as CoreTurnItem; use codex_protocol::mcp::Resource as McpResource; @@ -376,8 +376,7 @@ pub struct SandboxWorkspaceWrite { #[serde(rename_all = "snake_case")] #[ts(export_to = "v2/")] pub struct ToolsV2 { - #[serde(alias = "web_search_request")] - pub web_search: Option, + pub web_search: Option, pub view_image: Option, } @@ -402,7 +401,7 @@ pub struct ProfileV2 { pub model_reasoning_summary: Option, pub model_verbosity: Option, pub web_search: Option, - pub web_search_config: Option, + pub tools: Option, pub chatgpt_base_url: Option, #[serde(default, flatten)] pub additional: HashMap, @@ -500,7 +499,6 @@ pub struct Config { pub forced_chatgpt_workspace_id: Option, pub forced_login_method: Option, pub web_search: Option, - pub web_search_config: Option, pub tools: Option, pub profile: Option, #[serde(default)] diff --git a/codex-rs/app-server/tests/suite/v2/config_rpc.rs b/codex-rs/app-server/tests/suite/v2/config_rpc.rs index 0033972fbc8..99bf0d6d2cb 100644 --- a/codex-rs/app-server/tests/suite/v2/config_rpc.rs +++ b/codex-rs/app-server/tests/suite/v2/config_rpc.rs @@ -23,11 +23,9 @@ use codex_app_server_protocol::ToolsV2; use codex_app_server_protocol::WriteStatus; use codex_core::config::set_project_trust_level; use codex_protocol::config_types::TrustLevel; -use codex_protocol::config_types::WebSearchConfig; use codex_protocol::config_types::WebSearchContextSize; -use codex_protocol::config_types::WebSearchFilters; -use codex_protocol::config_types::WebSearchUserLocation; -use codex_protocol::config_types::WebSearchUserLocationType; +use codex_protocol::config_types::WebSearchLocation; +use codex_protocol::config_types::WebSearchToolConfig; use codex_protocol::openai_models::ReasoningEffort; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; @@ -98,8 +96,11 @@ async fn config_read_includes_tools() -> Result<()> { r#" model = "gpt-user" +[tools.web_search] +context_size = "low" +allowed_domains = ["example.com"] + [tools] -web_search = true view_image = false "#, )?; @@ -130,12 +131,28 @@ view_image = false assert_eq!( tools, ToolsV2 { - web_search: Some(true), + web_search: Some(WebSearchToolConfig { + context_size: Some(WebSearchContextSize::Low), + allowed_domains: Some(vec!["example.com".to_string()]), + location: None, + }), view_image: Some(false), } ); assert_eq!( - origins.get("tools.web_search").expect("origin").name, + origins + .get("tools.web_search.context_size") + .expect("origin") + .name, + ConfigLayerSource::User { + file: user_file.clone(), + } + ); + assert_eq!( + origins + .get("tools.web_search.allowed_domains") + .expect("origin") + .name, ConfigLayerSource::User { file: user_file.clone(), } @@ -154,23 +171,17 @@ view_image = false } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn config_read_includes_web_search_config() -> Result<()> { +async fn config_read_includes_nested_web_search_tool_config() -> Result<()> { let codex_home = TempDir::new()?; write_config( &codex_home, r#" web_search = "live" -[web_search_config] -search_context_size = "high" - -[web_search_config.filters] +[tools.web_search] +context_size = "high" allowed_domains = ["example.com"] - -[web_search_config.user_location] -country = "US" -city = "New York" -timezone = "America/New_York" +location = { country = "US", city = "New York", timezone = "America/New_York" } "#, )?; @@ -191,20 +202,17 @@ timezone = "America/New_York" let ConfigReadResponse { config, .. } = to_response(resp)?; assert_eq!( - config.web_search_config, - Some(WebSearchConfig { - filters: Some(WebSearchFilters { - allowed_domains: Some(vec!["example.com".to_string()]), - }), - user_location: Some(WebSearchUserLocation { - r#type: WebSearchUserLocationType::Approximate, + config.tools.expect("tools present").web_search, + Some(WebSearchToolConfig { + context_size: Some(WebSearchContextSize::High), + allowed_domains: Some(vec!["example.com".to_string()]), + location: Some(WebSearchLocation { country: Some("US".to_string()), region: None, city: Some("New York".to_string()), timezone: Some("America/New_York".to_string()), }), - search_context_size: Some(WebSearchContextSize::High), - }) + }), ); Ok(()) diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 0eb3f31c78e..62201e2be2c 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -530,18 +530,15 @@ "service_tier": { "$ref": "#/definitions/ServiceTier" }, - "tools_view_image": { - "type": "boolean" + "tools": { + "$ref": "#/definitions/ToolsToml" }, - "tools_web_search": { + "tools_view_image": { "type": "boolean" }, "web_search": { "$ref": "#/definitions/WebSearchMode" }, - "web_search_config": { - "$ref": "#/definitions/WebSearchConfig" - }, "windows": { "allOf": [ { @@ -1442,8 +1439,12 @@ "type": "boolean" }, "web_search": { - "default": null, - "type": "boolean" + "allOf": [ + { + "$ref": "#/definitions/WebSearchToolConfig" + } + ], + "default": null } }, "type": "object" @@ -1551,21 +1552,6 @@ ], "type": "string" }, - "WebSearchConfig": { - "additionalProperties": false, - "properties": { - "filters": { - "$ref": "#/definitions/WebSearchFilters" - }, - "search_context_size": { - "$ref": "#/definitions/WebSearchContextSize" - }, - "user_location": { - "$ref": "#/definitions/WebSearchUserLocation" - } - }, - "type": "object" - }, "WebSearchContextSize": { "enum": [ "low", @@ -1574,14 +1560,20 @@ ], "type": "string" }, - "WebSearchFilters": { + "WebSearchLocation": { "additionalProperties": false, "properties": { - "allowed_domains": { - "items": { - "type": "string" - }, - "type": "array" + "city": { + "type": "string" + }, + "country": { + "type": "string" + }, + "region": { + "type": "string" + }, + "timezone": { + "type": "string" } }, "type": "object" @@ -1594,38 +1586,24 @@ ], "type": "string" }, - "WebSearchUserLocation": { + "WebSearchToolConfig": { "additionalProperties": false, "properties": { - "city": { - "type": "string" - }, - "country": { - "type": "string" - }, - "region": { - "type": "string" + "allowed_domains": { + "items": { + "type": "string" + }, + "type": "array" }, - "timezone": { - "type": "string" + "context_size": { + "$ref": "#/definitions/WebSearchContextSize" }, - "type": { - "allOf": [ - { - "$ref": "#/definitions/WebSearchUserLocationType" - } - ], - "default": "approximate" + "location": { + "$ref": "#/definitions/WebSearchLocation" } }, "type": "object" }, - "WebSearchUserLocationType": { - "enum": [ - "approximate" - ], - "type": "string" - }, "WindowsSandboxModeToml": { "enum": [ "elevated", @@ -2294,14 +2272,6 @@ ], "description": "Controls the web search tool mode: disabled, cached, or live." }, - "web_search_config": { - "allOf": [ - { - "$ref": "#/definitions/WebSearchConfig" - } - ], - "description": "Optional structured configuration for the web search tool." - }, "windows": { "allOf": [ { diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 44d365233d6..d3d5db7926a 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -6867,3 +6867,3877 @@ pub(crate) use tests::make_session_configuration_for_tests; #[cfg(test)] #[path = "codex_tests.rs"] mod tests; + + struct InstructionsTestCase { + slug: &'static str, + expects_apply_patch_instructions: bool, + } + + fn user_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } + } + + fn assistant_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } + } + + fn skill_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } + } + + fn developer_input_texts(items: &[ResponseItem]) -> Vec<&str> { + items + .iter() + .filter_map(|item| match item { + ResponseItem::Message { role, content, .. } if role == "developer" => { + Some(content.as_slice()) + } + _ => None, + }) + .flat_map(|content| content.iter()) + .filter_map(|item| match item { + ContentItem::InputText { text } => Some(text.as_str()), + _ => None, + }) + .collect() + } + + fn make_connector(id: &str, name: &str) -> AppInfo { + AppInfo { + id: id.to_string(), + name: name.to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: Vec::new(), + } + } + + #[test] + fn assistant_message_stream_parsers_can_be_seeded_from_output_item_added_text() { + let mut parsers = AssistantMessageStreamParsers::new(false); + let item_id = "msg-1"; + + let seeded = parsers.seed_item_text(item_id, "hello doc"); + let parsed = parsers.parse_delta(item_id, "1 world"); + let tail = parsers.finish_item(item_id); + + assert_eq!(seeded.visible_text, "hello "); + assert_eq!(seeded.citations, Vec::::new()); + assert_eq!(parsed.visible_text, " world"); + assert_eq!(parsed.citations, vec!["doc1".to_string()]); + assert_eq!(tail.visible_text, ""); + assert_eq!(tail.citations, Vec::::new()); + } + + #[test] + fn assistant_message_stream_parsers_seed_buffered_prefix_stays_out_of_finish_tail() { + let mut parsers = AssistantMessageStreamParsers::new(false); + let item_id = "msg-1"; + + let seeded = parsers.seed_item_text(item_id, "hello doc world"); + let tail = parsers.finish_item(item_id); + + assert_eq!(seeded.visible_text, "hello "); + assert_eq!(seeded.citations, Vec::::new()); + assert_eq!(parsed.visible_text, " world"); + assert_eq!(parsed.citations, vec!["doc".to_string()]); + assert_eq!(tail.visible_text, ""); + assert_eq!(tail.citations, Vec::::new()); + } + + #[test] + fn assistant_message_stream_parsers_seed_plan_parser_across_added_and_delta_boundaries() { + let mut parsers = AssistantMessageStreamParsers::new(true); + let item_id = "msg-1"; + + let seeded = parsers.seed_item_text(item_id, "Intro\n\n- step\n\nOutro"); + let tail = parsers.finish_item(item_id); + + assert_eq!(seeded.visible_text, "Intro\n"); + assert_eq!( + seeded.plan_segments, + vec![ProposedPlanSegment::Normal("Intro\n".to_string())] + ); + assert_eq!(parsed.visible_text, "Outro"); + assert_eq!( + parsed.plan_segments, + vec![ + ProposedPlanSegment::ProposedPlanStart, + ProposedPlanSegment::ProposedPlanDelta("- step\n".to_string()), + ProposedPlanSegment::ProposedPlanEnd, + ProposedPlanSegment::Normal("Outro".to_string()), + ] + ); + assert_eq!(tail.visible_text, ""); + assert!(tail.plan_segments.is_empty()); + } + + fn make_mcp_tool( + server_name: &str, + tool_name: &str, + connector_id: Option<&str>, + connector_name: Option<&str>, + ) -> ToolInfo { + ToolInfo { + server_name: server_name.to_string(), + tool_name: tool_name.to_string(), + tool: Tool { + name: tool_name.to_string().into(), + title: None, + description: Some(format!("Test tool: {tool_name}").into()), + input_schema: Arc::new(JsonObject::default()), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: connector_id.map(str::to_string), + connector_name: connector_name.map(str::to_string), + plugin_display_names: Vec::new(), + } + } + + fn function_call_rollout_item(name: &str, call_id: &str) -> RolloutItem { + RolloutItem::ResponseItem(ResponseItem::FunctionCall { + id: None, + name: name.to_string(), + arguments: "{}".to_string(), + call_id: call_id.to_string(), + }) + } + + fn function_call_output_rollout_item(call_id: &str, output: &str) -> RolloutItem { + RolloutItem::ResponseItem(ResponseItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload::from_text(output.to_string()), + }) + } + + #[test] + fn validated_network_policy_amendment_host_allows_normalized_match() { + let amendment = NetworkPolicyAmendment { + host: "ExAmPlE.Com.:443".to_string(), + action: NetworkPolicyRuleAction::Allow, + }; + let context = NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Https, + }; + + let host = Session::validated_network_policy_amendment_host(&amendment, &context) + .expect("normalized hosts should match"); + + assert_eq!(host, "example.com"); + } + + #[test] + fn validated_network_policy_amendment_host_rejects_mismatch() { + let amendment = NetworkPolicyAmendment { + host: "evil.example.com".to_string(), + action: NetworkPolicyRuleAction::Deny, + }; + let context = NetworkApprovalContext { + host: "api.example.com".to_string(), + protocol: NetworkApprovalProtocol::Https, + }; + + let err = Session::validated_network_policy_amendment_host(&amendment, &context) + .expect_err("mismatched hosts should be rejected"); + + let message = err.to_string(); + assert!(message.contains("does not match approved host")); + } + + #[tokio::test] + async fn get_base_instructions_no_user_content() { + let prompt_with_apply_patch_instructions = + include_str!("../prompt_with_apply_patch_instructions.md"); + let models_response: ModelsResponse = + serde_json::from_str(include_str!("../models.json")).expect("valid models.json"); + let model_info_for_slug = |slug: &str, config: &Config| { + let model = models_response + .models + .iter() + .find(|candidate| candidate.slug == slug) + .cloned() + .unwrap_or_else(|| panic!("model slug {slug} is missing from models.json")); + model_info::with_config_overrides(model, config) + }; + let test_cases = vec![ + InstructionsTestCase { + slug: "gpt-5", + expects_apply_patch_instructions: false, + }, + InstructionsTestCase { + slug: "gpt-5.1", + expects_apply_patch_instructions: false, + }, + InstructionsTestCase { + slug: "gpt-5.1-codex", + expects_apply_patch_instructions: false, + }, + InstructionsTestCase { + slug: "gpt-5.1-codex-max", + expects_apply_patch_instructions: false, + }, + ]; + + let (session, _turn_context) = make_session_and_context().await; + let config = test_config(); + + for test_case in test_cases { + let model_info = model_info_for_slug(test_case.slug, &config); + if test_case.expects_apply_patch_instructions { + assert_eq!( + model_info.base_instructions.as_str(), + prompt_with_apply_patch_instructions + ); + } + + { + let mut state = session.state.lock().await; + state.session_configuration.base_instructions = + model_info.base_instructions.clone(); + } + + let base_instructions = session.get_base_instructions().await; + assert_eq!(base_instructions.text, model_info.base_instructions); + } + } + + #[tokio::test] + async fn reload_user_config_layer_updates_effective_apps_config() { + let (session, _turn_context) = make_session_and_context().await; + let codex_home = session.codex_home().await; + std::fs::create_dir_all(&codex_home).expect("create codex home"); + let config_toml_path = codex_home.join(CONFIG_TOML_FILE); + std::fs::write( + &config_toml_path, + "[apps.calendar]\nenabled = false\ndestructive_enabled = false\n", + ) + .expect("write user config"); + + session.reload_user_config_layer().await; + + let config = session.get_config().await; + let apps_toml = config + .config_layer_stack + .effective_config() + .as_table() + .and_then(|table| table.get("apps")) + .cloned() + .expect("apps table"); + let apps = crate::config::types::AppsConfigToml::deserialize(apps_toml) + .expect("deserialize apps config"); + let app = apps + .apps + .get("calendar") + .expect("calendar app config exists"); + + assert!(!app.enabled); + assert_eq!(app.destructive_enabled, Some(false)); + } + + #[test] + fn filter_connectors_for_input_skips_duplicate_slug_mentions() { + let connectors = vec![ + make_connector("one", "Foo Bar"), + make_connector("two", "Foo-Bar"), + ]; + let input = vec![user_message("use $foo-bar")]; + let explicitly_enabled_connectors = HashSet::new(); + let skill_name_counts_lower = HashMap::new(); + + let selected = filter_connectors_for_input( + &connectors, + &input, + &explicitly_enabled_connectors, + &skill_name_counts_lower, + ); + + assert_eq!(selected, Vec::new()); + } + + #[test] + fn filter_connectors_for_input_skips_when_skill_name_conflicts() { + let connectors = vec![make_connector("one", "Todoist")]; + let input = vec![user_message("use $todoist")]; + let explicitly_enabled_connectors = HashSet::new(); + let skill_name_counts_lower = HashMap::from([("todoist".to_string(), 1)]); + + let selected = filter_connectors_for_input( + &connectors, + &input, + &explicitly_enabled_connectors, + &skill_name_counts_lower, + ); + + assert_eq!(selected, Vec::new()); + } + + #[test] + fn filter_connectors_for_input_skips_disabled_connectors() { + let mut connector = make_connector("calendar", "Calendar"); + connector.is_enabled = false; + let input = vec![user_message("use $calendar")]; + let explicitly_enabled_connectors = HashSet::new(); + let selected = filter_connectors_for_input( + &[connector], + &input, + &explicitly_enabled_connectors, + &HashMap::new(), + ); + + assert_eq!(selected, Vec::new()); + } + + #[test] + fn collect_explicit_app_ids_from_skill_items_includes_linked_mentions() { + let connectors = vec![make_connector("calendar", "Calendar")]; + let skill_items = vec![skill_message( + "\ndemo\n/tmp/skills/demo/SKILL.md\nuse [$calendar](app://calendar)\n", + )]; + + let connector_ids = + collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); + + assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); + } + + #[test] + fn collect_explicit_app_ids_from_skill_items_resolves_unambiguous_plain_mentions() { + let connectors = vec![make_connector("calendar", "Calendar")]; + let skill_items = vec![skill_message( + "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", + )]; + + let connector_ids = + collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); + + assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); + } + + #[test] + fn collect_explicit_app_ids_from_skill_items_skips_plain_mentions_with_skill_conflicts() { + let connectors = vec![make_connector("calendar", "Calendar")]; + let skill_items = vec![skill_message( + "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", + )]; + let skill_name_counts_lower = HashMap::from([("calendar".to_string(), 1)]); + + let connector_ids = collect_explicit_app_ids_from_skill_items( + &skill_items, + &connectors, + &skill_name_counts_lower, + ); + + assert_eq!(connector_ids, HashSet::::new()); + } + + #[test] + fn non_app_mcp_tools_remain_visible_without_search_selection() { + let mcp_tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + make_mcp_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + Some("calendar"), + Some("Calendar"), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + make_mcp_tool("rmcp", "echo", None, None), + ), + ]); + + let mut selected_mcp_tools = mcp_tools + .iter() + .filter(|(_, tool)| tool.server_name != CODEX_APPS_MCP_SERVER_NAME) + .map(|(name, tool)| (name.clone(), tool.clone())) + .collect::>(); + + let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); + let explicitly_enabled_connectors = HashSet::new(); + let connectors = filter_connectors_for_input( + &connectors, + &[user_message("run echo")], + &explicitly_enabled_connectors, + &HashMap::new(), + ); + let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); + selected_mcp_tools.extend(apps_mcp_tools); + + let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); + tool_names.sort(); + assert_eq!(tool_names, vec!["mcp__rmcp__echo".to_string()]); + } + + #[test] + fn search_tool_selection_keeps_codex_apps_tools_without_mentions() { + let selected_tool_names = vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__rmcp__echo".to_string(), + ]; + let mcp_tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + make_mcp_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + Some("calendar"), + Some("Calendar"), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + make_mcp_tool("rmcp", "echo", None, None), + ), + ]); + + let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); + let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); + let explicitly_enabled_connectors = HashSet::new(); + let connectors = filter_connectors_for_input( + &connectors, + &[user_message("run the selected tools")], + &explicitly_enabled_connectors, + &HashMap::new(), + ); + let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); + selected_mcp_tools.extend(apps_mcp_tools); + + let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); + tool_names.sort(); + assert_eq!( + tool_names, + vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__rmcp__echo".to_string(), + ] + ); + } + + #[test] + fn apps_mentions_add_codex_apps_tools_to_search_selected_set() { + let selected_tool_names = vec!["mcp__rmcp__echo".to_string()]; + let mcp_tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + make_mcp_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + Some("calendar"), + Some("Calendar"), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + make_mcp_tool("rmcp", "echo", None, None), + ), + ]); + + let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); + let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); + let explicitly_enabled_connectors = HashSet::new(); + let connectors = filter_connectors_for_input( + &connectors, + &[user_message("use $calendar and then echo the response")], + &explicitly_enabled_connectors, + &HashMap::new(), + ); + let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); + selected_mcp_tools.extend(apps_mcp_tools); + + let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); + tool_names.sort(); + assert_eq!( + tool_names, + vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__rmcp__echo".to_string(), + ] + ); + } + + #[test] + fn extract_mcp_tool_selection_from_rollout_reads_search_tool_output() { + let rollout_items = vec![ + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), + function_call_output_rollout_item( + "search-1", + &json!({ + "active_selected_tools": [ + "mcp__codex_apps__calendar_create_event", + "mcp__codex_apps__calendar_list_events", + ], + }) + .to_string(), + ), + ]; + + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!( + selected, + Some(vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__codex_apps__calendar_list_events".to_string(), + ]) + ); + } + + #[test] + fn extract_mcp_tool_selection_from_rollout_latest_valid_payload_wins() { + let rollout_items = vec![ + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), + function_call_output_rollout_item( + "search-1", + &json!({ + "active_selected_tools": ["mcp__codex_apps__calendar_create_event"], + }) + .to_string(), + ), + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-2"), + function_call_output_rollout_item( + "search-2", + &json!({ + "active_selected_tools": ["mcp__codex_apps__calendar_delete_event"], + }) + .to_string(), + ), + ]; + + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!( + selected, + Some(vec!["mcp__codex_apps__calendar_delete_event".to_string(),]) + ); + } + + #[test] + fn extract_mcp_tool_selection_from_rollout_ignores_non_search_and_malformed_payloads() { + let rollout_items = vec![ + function_call_rollout_item("shell", "shell-1"), + function_call_output_rollout_item( + "shell-1", + &json!({ + "active_selected_tools": ["mcp__codex_apps__should_be_ignored"], + }) + .to_string(), + ), + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), + function_call_output_rollout_item("search-1", "{not-json"), + function_call_output_rollout_item( + "unknown-search-call", + &json!({ + "active_selected_tools": ["mcp__codex_apps__also_ignored"], + }) + .to_string(), + ), + function_call_output_rollout_item( + "search-1", + &json!({ + "active_selected_tools": ["mcp__codex_apps__calendar_list_events"], + }) + .to_string(), + ), + ]; + + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!( + selected, + Some(vec!["mcp__codex_apps__calendar_list_events".to_string(),]) + ); + } + + #[test] + fn extract_mcp_tool_selection_from_rollout_returns_none_without_valid_search_output() { + let rollout_items = vec![function_call_rollout_item( + SEARCH_TOOL_BM25_TOOL_NAME, + "search-1", + )]; + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!(selected, None); + } + + #[tokio::test] + async fn reconstruct_history_matches_live_compactions() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; + + let reconstruction_turn = session.new_default_turn().await; + let reconstructed = session + .reconstruct_history_from_rollout(reconstruction_turn.as_ref(), &rollout_items) + .await; + + assert_eq!(expected, reconstructed.history); + } + + #[tokio::test] + async fn reconstruct_history_uses_replacement_history_verbatim() { + let (session, turn_context) = make_session_and_context().await; + let summary_item = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }; + let replacement_history = vec![ + summary_item.clone(), + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "stale developer instructions".to_string(), + }], + end_turn: None, + phase: None, + }, + ]; + let rollout_items = vec![RolloutItem::Compacted(CompactedItem { + message: String::new(), + replacement_history: Some(replacement_history.clone()), + })]; + + let reconstructed = session + .reconstruct_history_from_rollout(&turn_context, &rollout_items) + .await; + + assert_eq!(reconstructed.history, replacement_history); + } + + #[tokio::test] + async fn record_initial_history_reconstructs_resumed_transcript() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + let history = session.state.lock().await.clone_history(); + assert_eq!(expected, history.raw_items()); + } + + #[tokio::test] + async fn resumed_history_injects_initial_context_on_first_context_update_only() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + let history_before_seed = session.state.lock().await.clone_history(); + assert_eq!(expected, history_before_seed.raw_items()); + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + expected.extend(session.build_initial_context(&turn_context).await); + let history_after_seed = session.clone_history().await; + assert_eq!(expected, history_after_seed.raw_items()); + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + let history_after_second_seed = session.clone_history().await; + assert_eq!( + history_after_seed.raw_items(), + history_after_second_seed.raw_items() + ); + } + + #[tokio::test] + async fn record_initial_history_seeds_token_info_from_rollout() { + let (session, turn_context) = make_session_and_context().await; + let (mut rollout_items, _expected) = sample_rollout(&session, &turn_context).await; + + let info1 = TokenUsageInfo { + total_token_usage: TokenUsage { + input_tokens: 10, + cached_input_tokens: 0, + output_tokens: 20, + reasoning_output_tokens: 0, + total_tokens: 30, + }, + last_token_usage: TokenUsage { + input_tokens: 3, + cached_input_tokens: 0, + output_tokens: 4, + reasoning_output_tokens: 0, + total_tokens: 7, + }, + model_context_window: Some(1_000), + }; + let info2 = TokenUsageInfo { + total_token_usage: TokenUsage { + input_tokens: 100, + cached_input_tokens: 50, + output_tokens: 200, + reasoning_output_tokens: 25, + total_tokens: 375, + }, + last_token_usage: TokenUsage { + input_tokens: 10, + cached_input_tokens: 0, + output_tokens: 20, + reasoning_output_tokens: 5, + total_tokens: 35, + }, + model_context_window: Some(2_000), + }; + + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: Some(info1), + rate_limits: None, + }, + ))); + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: None, + rate_limits: None, + }, + ))); + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: Some(info2.clone()), + rate_limits: None, + }, + ))); + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: None, + rate_limits: None, + }, + ))); + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + let actual = session.state.lock().await.token_info(); + assert_eq!(actual, Some(info2)); + } + + #[tokio::test] + async fn recompute_token_usage_uses_session_base_instructions() { + let (session, turn_context) = make_session_and_context().await; + + let override_instructions = "SESSION_OVERRIDE_INSTRUCTIONS_ONLY".repeat(120); + { + let mut state = session.state.lock().await; + state.session_configuration.base_instructions = override_instructions.clone(); + } + + let item = user_message("hello"); + session + .record_into_history(std::slice::from_ref(&item), &turn_context) + .await; + + let history = session.clone_history().await; + let session_base_instructions = BaseInstructions { + text: override_instructions, + }; + let expected_tokens = history + .estimate_token_count_with_base_instructions(&session_base_instructions) + .expect("estimate with session base instructions"); + let model_estimated_tokens = history + .estimate_token_count(&turn_context) + .expect("estimate with model instructions"); + assert_ne!(expected_tokens, model_estimated_tokens); + + session.recompute_token_usage(&turn_context).await; + + let actual_tokens = session + .state + .lock() + .await + .token_info() + .expect("token info") + .last_token_usage + .total_tokens; + assert_eq!(actual_tokens, expected_tokens.max(0)); + } + + #[tokio::test] + async fn recompute_token_usage_updates_model_context_window() { + let (session, mut turn_context) = make_session_and_context().await; + + { + let mut state = session.state.lock().await; + state.set_token_info(Some(TokenUsageInfo { + total_token_usage: TokenUsage::default(), + last_token_usage: TokenUsage::default(), + model_context_window: Some(258_400), + })); + } + + turn_context.model_info.context_window = Some(128_000); + turn_context.model_info.effective_context_window_percent = 100; + + session.recompute_token_usage(&turn_context).await; + + let actual = session.state.lock().await.token_info().expect("token info"); + assert_eq!(actual.model_context_window, Some(128_000)); + } + + #[tokio::test] + async fn record_initial_history_reconstructs_forked_transcript() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; + + session + .record_initial_history(InitialHistory::Forked(rollout_items)) + .await; + + let reconstruction_turn = session.new_default_turn().await; + expected.extend( + session + .build_initial_context(reconstruction_turn.as_ref()) + .await, + ); + let history = session.state.lock().await.clone_history(); + assert_eq!(expected, history.raw_items()); + } + + #[tokio::test] + async fn record_initial_history_forked_hydrates_previous_turn_settings() { + let (session, turn_context) = make_session_and_context().await; + let previous_model = "forked-rollout-model"; + let previous_context_item = TurnContextItem { + turn_id: Some(turn_context.sub_id.clone()), + trace_id: turn_context.trace_id.clone(), + cwd: turn_context.cwd.clone(), + current_date: turn_context.current_date.clone(), + timezone: turn_context.timezone.clone(), + approval_policy: turn_context.approval_policy.value(), + sandbox_policy: turn_context.sandbox_policy.get().clone(), + network: None, + model: previous_model.to_string(), + personality: turn_context.personality, + collaboration_mode: Some(turn_context.collaboration_mode.clone()), + realtime_active: Some(turn_context.realtime_active), + effort: turn_context.reasoning_effort, + summary: turn_context.reasoning_summary, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: Some(turn_context.truncation_policy.into()), + }; + let turn_id = previous_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let rollout_items = vec![ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage( + codex_protocol::protocol::UserMessageEvent { + message: "forked seed".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + )), + RolloutItem::TurnContext(previous_context_item), + RolloutItem::EventMsg(EventMsg::TurnComplete( + codex_protocol::protocol::TurnCompleteEvent { + turn_id, + last_agent_message: None, + }, + )), + ]; + + session + .record_initial_history(InitialHistory::Forked(rollout_items)) + .await; + + assert_eq!( + session.previous_turn_settings().await, + Some(PreviousTurnSettings { + model: previous_model.to_string(), + realtime_active: Some(turn_context.realtime_active), + }) + ); + } + + #[tokio::test] + async fn thread_rollback_drops_last_turn_from_history() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let rollout_path = attach_rollout_recorder(&sess).await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + let turn_1 = vec![ + user_message("turn 1 user"), + assistant_message("turn 1 assistant"), + ]; + let turn_2 = vec![ + user_message("turn 2 user"), + assistant_message("turn 2 assistant"), + ]; + let mut full_history = Vec::new(); + full_history.extend(initial_context.clone()); + full_history.extend(turn_1.clone()); + full_history.extend(turn_2); + sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) + .await; + let rollout_items: Vec = full_history + .into_iter() + .map(RolloutItem::ResponseItem) + .collect(); + sess.persist_rollout_items(&rollout_items).await; + sess.set_previous_turn_settings(Some(PreviousTurnSettings { + model: "stale-model".to_string(), + realtime_active: Some(tc.realtime_active), + })) + .await; + { + let mut state = sess.state.lock().await; + state.set_reference_context_item(Some(tc.to_turn_context_item())); + } + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 1); + + let mut expected = Vec::new(); + expected.extend(initial_context); + expected.extend(turn_1); + + let history = sess.clone_history().await; + assert_eq!(expected, history.raw_items()); + assert_eq!(sess.previous_turn_settings().await, None); + assert!(sess.reference_context_item().await.is_none()); + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + assert!(resumed.history.iter().any(|item| { + matches!( + item, + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) + if rollback.num_turns == 1 + ) + })); + } + + #[tokio::test] + async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + attach_rollout_recorder(&sess).await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + let turn_1 = vec![user_message("turn 1 user")]; + let mut full_history = Vec::new(); + full_history.extend(initial_context.clone()); + full_history.extend(turn_1); + sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) + .await; + let rollout_items: Vec = full_history + .into_iter() + .map(RolloutItem::ResponseItem) + .collect(); + sess.persist_rollout_items(&rollout_items).await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 99).await; + + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 99); + + let history = sess.clone_history().await; + assert_eq!(initial_context, history.raw_items()); + } + + #[tokio::test] + async fn thread_rollback_fails_without_persisted_rollout_path() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!( + error_event.message, + "thread rollback requires a persisted rollout path" + ); + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + assert_eq!(sess.clone_history().await.raw_items(), initial_context); + } + + #[tokio::test] + async fn thread_rollback_recomputes_previous_turn_settings_and_reference_context_from_replay() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + attach_rollout_recorder(&sess).await; + + let first_context_item = tc.to_turn_context_item(); + let first_turn_id = first_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let mut rolled_back_context_item = first_context_item.clone(); + rolled_back_context_item.turn_id = Some("rolled-back-turn".to_string()); + rolled_back_context_item.model = "rolled-back-model".to_string(); + let rolled_back_turn_id = rolled_back_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let turn_one_user = user_message("turn 1 user"); + let turn_one_assistant = assistant_message("turn 1 assistant"); + let turn_two_user = user_message("turn 2 user"); + let turn_two_assistant = assistant_message("turn 2 assistant"); + + sess.persist_rollout_items(&[ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: first_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage( + codex_protocol::protocol::UserMessageEvent { + message: "turn 1 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + )), + RolloutItem::TurnContext(first_context_item.clone()), + RolloutItem::ResponseItem(turn_one_user.clone()), + RolloutItem::ResponseItem(turn_one_assistant.clone()), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: first_turn_id, + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: rolled_back_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage( + codex_protocol::protocol::UserMessageEvent { + message: "turn 2 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + )), + RolloutItem::TurnContext(rolled_back_context_item), + RolloutItem::ResponseItem(turn_two_user), + RolloutItem::ResponseItem(turn_two_assistant), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: rolled_back_turn_id, + last_agent_message: None, + })), + ]) + .await; + sess.replace_history( + vec![assistant_message("stale history")], + Some(first_context_item.clone()), + ) + .await; + sess.set_previous_turn_settings(Some(PreviousTurnSettings { + model: "stale-model".to_string(), + realtime_active: None, + })) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 1); + + assert_eq!( + sess.clone_history().await.raw_items(), + vec![turn_one_user, turn_one_assistant] + ); + assert_eq!( + sess.previous_turn_settings().await, + Some(PreviousTurnSettings { + model: tc.model_info.slug.clone(), + realtime_active: Some(tc.realtime_active), + }) + ); + assert_eq!( + serde_json::to_value(sess.reference_context_item().await) + .expect("serialize replay reference context item"), + serde_json::to_value(Some(first_context_item)) + .expect("serialize expected reference context item") + ); + } + + #[tokio::test] + async fn thread_rollback_persists_marker_and_replays_cumulatively() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let rollout_path = attach_rollout_recorder(&sess).await; + let turn_context_item = tc.to_turn_context_item(); + + sess.persist_rollout_items(&[ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: "turn-1".to_string(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 1 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(turn_context_item.clone()), + RolloutItem::ResponseItem(user_message("turn 1 user")), + RolloutItem::ResponseItem(assistant_message("turn 1 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-1".to_string(), + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: "turn-2".to_string(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 2 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(turn_context_item.clone()), + RolloutItem::ResponseItem(user_message("turn 2 user")), + RolloutItem::ResponseItem(assistant_message("turn 2 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-2".to_string(), + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: "turn-3".to_string(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 3 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(turn_context_item), + RolloutItem::ResponseItem(user_message("turn 3 user")), + RolloutItem::ResponseItem(assistant_message("turn 3 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-3".to_string(), + last_agent_message: None, + })), + ]) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let first_rollback = wait_for_thread_rolled_back(&rx).await; + assert_eq!(first_rollback.num_turns, 1); + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let second_rollback = wait_for_thread_rolled_back(&rx).await; + assert_eq!(second_rollback.num_turns, 1); + + assert_eq!( + sess.clone_history().await.raw_items(), + vec![ + user_message("turn 1 user"), + assistant_message("turn 1 assistant") + ] + ); + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + let rollback_markers = resumed + .history + .iter() + .filter(|item| matches!(item, RolloutItem::EventMsg(EventMsg::ThreadRolledBack(_)))) + .count(); + assert_eq!(rollback_markers, 2); + } + + #[tokio::test] + async fn thread_rollback_fails_when_turn_in_progress() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + *sess.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + + let history = sess.clone_history().await; + assert_eq!(initial_context, history.raw_items()); + } + + #[tokio::test] + async fn thread_rollback_fails_when_num_turns_is_zero() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 0).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!(error_event.message, "num_turns must be >= 1"); + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + + let history = sess.clone_history().await; + assert_eq!(initial_context, history.raw_items()); + } + + #[tokio::test] + async fn set_rate_limits_retains_previous_credits() { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = + ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + + let mut state = SessionState::new(session_configuration); + let initial = RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 10.0, + window_minutes: Some(15), + resets_at: Some(1_700), + }), + secondary: None, + credits: Some(CreditsSnapshot { + has_credits: true, + unlimited: false, + balance: Some("10.00".to_string()), + }), + plan_type: Some(codex_protocol::account::PlanType::Plus), + }; + state.set_rate_limits(initial.clone()); + + let update = RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: Some("codex_other".to_string()), + primary: Some(RateLimitWindow { + used_percent: 40.0, + window_minutes: Some(30), + resets_at: Some(1_800), + }), + secondary: Some(RateLimitWindow { + used_percent: 5.0, + window_minutes: Some(60), + resets_at: Some(1_900), + }), + credits: None, + plan_type: None, + }; + state.set_rate_limits(update.clone()); + + assert_eq!( + state.latest_rate_limits, + Some(RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: Some("codex_other".to_string()), + primary: update.primary.clone(), + secondary: update.secondary, + credits: initial.credits, + plan_type: initial.plan_type, + }) + ); + } + + #[tokio::test] + async fn set_rate_limits_updates_plan_type_when_present() { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = + ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + + let mut state = SessionState::new(session_configuration); + let initial = RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 15.0, + window_minutes: Some(20), + resets_at: Some(1_600), + }), + secondary: Some(RateLimitWindow { + used_percent: 5.0, + window_minutes: Some(45), + resets_at: Some(1_650), + }), + credits: Some(CreditsSnapshot { + has_credits: true, + unlimited: false, + balance: Some("15.00".to_string()), + }), + plan_type: Some(codex_protocol::account::PlanType::Plus), + }; + state.set_rate_limits(initial.clone()); + + let update = RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 35.0, + window_minutes: Some(25), + resets_at: Some(1_700), + }), + secondary: None, + credits: None, + plan_type: Some(codex_protocol::account::PlanType::Pro), + }; + state.set_rate_limits(update.clone()); + + assert_eq!( + state.latest_rate_limits, + Some(RateLimitSnapshot { + limit_id: Some("codex".to_string()), + limit_name: None, + primary: update.primary, + secondary: update.secondary, + credits: initial.credits, + plan_type: update.plan_type, + }) + ); + } + + #[test] + fn prefers_structured_content_when_present() { + let ctr = McpCallToolResult { + // Content present but should be ignored because structured_content is set. + content: vec![text_block("ignored")], + is_error: None, + structured_content: Some(json!({ + "ok": true, + "value": 42 + })), + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&json!({ + "ok": true, + "value": 42 + })) + .unwrap(), + ), + success: Some(true), + }; + + assert_eq!(expected, got); + } + + #[tokio::test] + async fn includes_timed_out_message() { + let exec = ExecToolCallOutput { + exit_code: 0, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(String::new()), + aggregated_output: StreamOutput::new("Command output".to_string()), + duration: StdDuration::from_secs(1), + timed_out: true, + }; + let (_, turn_context) = make_session_and_context().await; + + let out = format_exec_output_str(&exec, turn_context.truncation_policy); + + assert_eq!( + out, + "command timed out after 1000 milliseconds\nCommand output" + ); + } + + #[tokio::test] + async fn turn_context_with_model_updates_model_fields() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.reasoning_effort = Some(ReasoningEffortConfig::Minimal); + let updated = turn_context + .with_model("gpt-5.1".to_string(), &session.services.models_manager) + .await; + let expected_model_info = session + .services + .models_manager + .get_model_info("gpt-5.1", updated.config.as_ref()) + .await; + + assert_eq!(updated.config.model.as_deref(), Some("gpt-5.1")); + assert_eq!(updated.collaboration_mode.model(), "gpt-5.1"); + assert_eq!(updated.model_info, expected_model_info); + assert_eq!( + updated.reasoning_effort, + Some(ReasoningEffortConfig::Medium) + ); + assert_eq!( + updated.collaboration_mode.reasoning_effort(), + Some(ReasoningEffortConfig::Medium) + ); + assert_eq!( + updated.config.model_reasoning_effort, + Some(ReasoningEffortConfig::Medium) + ); + assert_eq!( + updated.truncation_policy, + expected_model_info.truncation_policy.into() + ); + assert!(!Arc::ptr_eq( + &updated.tool_call_gate, + &turn_context.tool_call_gate + )); + } + + #[test] + fn falls_back_to_content_when_structured_is_null() { + let ctr = McpCallToolResult { + content: vec![text_block("hello"), text_block("world")], + is_error: None, + structured_content: Some(serde_json::Value::Null), + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&vec![text_block("hello"), text_block("world")]).unwrap(), + ), + success: Some(true), + }; + + assert_eq!(expected, got); + } + + #[test] + fn success_flag_reflects_is_error_true() { + let ctr = McpCallToolResult { + content: vec![text_block("unused")], + is_error: Some(true), + structured_content: Some(json!({ "message": "bad" })), + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&json!({ "message": "bad" })).unwrap(), + ), + success: Some(false), + }; + + assert_eq!(expected, got); + } + + #[test] + fn success_flag_true_with_no_error_and_content_used() { + let ctr = McpCallToolResult { + content: vec![text_block("alpha")], + is_error: Some(false), + structured_content: None, + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&vec![text_block("alpha")]).unwrap(), + ), + success: Some(true), + }; + + assert_eq!(expected, got); + } + + async fn wait_for_thread_rolled_back( + rx: &async_channel::Receiver, + ) -> crate::protocol::ThreadRolledBackEvent { + let deadline = StdDuration::from_secs(2); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::ThreadRolledBack(payload) => return payload, + _ => continue, + } + } + } + + async fn wait_for_thread_rollback_failed(rx: &async_channel::Receiver) -> ErrorEvent { + let deadline = StdDuration::from_secs(2); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::Error(payload) + if payload.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) => + { + return payload; + } + _ => continue, + } + } + } + + async fn attach_rollout_recorder(session: &Arc) -> PathBuf { + let config = session.get_config().await; + let recorder = RolloutRecorder::new( + config.as_ref(), + RolloutRecorderParams::new( + ThreadId::default(), + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await + .expect("create rollout recorder"); + let rollout_path = recorder.rollout_path().to_path_buf(); + { + let mut rollout = session.services.rollout.lock().await; + *rollout = Some(recorder); + } + session.ensure_rollout_materialized().await; + session.flush_rollout().await; + rollout_path + } + + fn text_block(s: &str) -> serde_json::Value { + json!({ + "type": "text", + "text": s, + }) + } + + fn init_test_tracing() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + let provider = SdkTracerProvider::builder().build(); + let tracer = provider.tracer("codex-core-tests"); + let subscriber = tracing_subscriber::registry() + .with(tracing_opentelemetry::layer().with_tracer(tracer)); + tracing::subscriber::set_global_default(subscriber) + .expect("global tracing subscriber should only be installed once"); + }); + } + + async fn build_test_config(codex_home: &Path) -> Config { + ConfigBuilder::default() + .codex_home(codex_home.to_path_buf()) + .build() + .await + .expect("load default test config") + } + + fn otel_manager( + conversation_id: ThreadId, + config: &Config, + model_info: &ModelInfo, + session_source: SessionSource, + ) -> OtelManager { + OtelManager::new( + conversation_id, + ModelsManager::get_model_offline_for_tests(config.model.as_deref()).as_str(), + model_info.slug.as_str(), + None, + Some("test@test.com".to_string()), + Some(TelemetryAuthMode::Chatgpt), + "test_originator".to_string(), + false, + "test".to_string(), + session_source, + ) + } + + pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = + ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + + SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + } + } + + #[tokio::test] + async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let mut config = build_test_config(codex_home.path()).await; + config + .features + .enable(Feature::ShellZshFork) + .expect("test config should allow shell_zsh_fork"); + config.zsh_path = None; + let config = Arc::new(config); + + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let models_manager = Arc::new(ModelsManager::new( + config.codex_home.clone(), + auth_manager.clone(), + None, + CollaborationModesConfig::default(), + )); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = + ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort: config.model_reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + + let (tx_event, _rx_event) = async_channel::unbounded(); + let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); + let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); + let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); + let skills_manager = Arc::new(SkillsManager::new( + config.codex_home.clone(), + Arc::clone(&plugins_manager), + )); + let result = Session::new( + session_configuration, + Arc::clone(&config), + auth_manager, + models_manager, + ExecPolicyManager::default(), + tx_event, + agent_status_tx, + InitialHistory::New, + SessionSource::Exec, + skills_manager, + plugins_manager, + mcp_manager, + Arc::new(FileWatcher::noop()), + AgentControl::default(), + ) + .await; + + let err = match result { + Ok(_) => panic!("expected startup to fail"), + Err(err) => err, + }; + let msg = format!("{err:#}"); + assert!(msg.contains("zsh fork feature enabled, but `zsh_path` is not configured")); + } + + // todo: use online model info + pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { + let (tx_event, _rx_event) = async_channel::unbounded(); + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let conversation_id = ThreadId::default(); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let models_manager = Arc::new(ModelsManager::new( + config.codex_home.clone(), + auth_manager.clone(), + None, + CollaborationModesConfig::default(), + )); + let agent_control = AgentControl::default(); + let exec_policy = ExecPolicyManager::default(); + let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = + ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + let per_turn_config = Session::build_per_turn_config(&session_configuration); + let model_info = ModelsManager::construct_model_info_offline_for_tests( + session_configuration.collaboration_mode.model(), + &per_turn_config, + ); + let otel_manager = otel_manager( + conversation_id, + config.as_ref(), + &model_info, + session_configuration.session_source.clone(), + ); + + let state = SessionState::new(session_configuration.clone()); + let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); + let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); + let skills_manager = Arc::new(SkillsManager::new( + config.codex_home.clone(), + Arc::clone(&plugins_manager), + )); + let network_approval = Arc::new(NetworkApprovalService::default()); + + let file_watcher = Arc::new(FileWatcher::noop()); + let services = SessionServices { + mcp_connection_manager: Arc::new(RwLock::new( + McpConnectionManager::new_mcp_connection_manager_for_tests( + &config.permissions.approval_policy, + ), + )), + mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), + unified_exec_manager: UnifiedExecProcessManager::new( + config.background_terminal_max_timeout, + ), + shell_zsh_path: None, + main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), + analytics_events_client: AnalyticsEventsClient::new( + Arc::clone(&config), + Arc::clone(&auth_manager), + ), + hooks: Hooks::new(HooksConfig { + legacy_notify_argv: config.notify.clone(), + }), + rollout: Mutex::new(None), + user_shell: Arc::new(default_user_shell()), + shell_snapshot_tx: watch::channel(None).0, + show_raw_agent_reasoning: config.show_raw_agent_reasoning, + exec_policy, + auth_manager: auth_manager.clone(), + otel_manager: otel_manager.clone(), + models_manager: Arc::clone(&models_manager), + tool_approvals: Mutex::new(ApprovalStore::default()), + execve_session_approvals: RwLock::new(HashMap::new()), + skills_manager, + plugins_manager, + mcp_manager, + file_watcher, + agent_control, + network_proxy: None, + network_approval: Arc::clone(&network_approval), + state_db: None, + model_client: ModelClient::new( + Some(auth_manager.clone()), + conversation_id, + session_configuration.provider.clone(), + session_configuration.session_source.clone(), + config.model_verbosity, + ws_version_from_features(config.as_ref()), + config.features.enabled(Feature::EnableRequestCompression), + config.features.enabled(Feature::RuntimeMetrics), + Session::build_model_client_beta_features_header(config.as_ref()), + ), + }; + let js_repl = Arc::new(JsReplHandle::with_node_path( + config.js_repl_node_path.clone(), + config.js_repl_node_module_dirs.clone(), + )); + + let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); + let turn_context = Session::make_turn_context( + Some(Arc::clone(&auth_manager)), + &otel_manager, + session_configuration.provider.clone(), + &session_configuration, + per_turn_config, + model_info, + None, + "turn_id".to_string(), + Arc::clone(&js_repl), + skills_outcome, + ); + + let session = Session { + conversation_id, + tx_event, + agent_status: agent_status_tx, + state: Mutex::new(state), + features: config.features.clone(), + pending_mcp_server_refresh_config: Mutex::new(None), + conversation: Arc::new(RealtimeConversationManager::new()), + active_turn: Mutex::new(None), + services, + js_repl, + next_internal_sub_id: AtomicU64::new(0), + }; + + (session, turn_context) + } + + #[tokio::test] + async fn submit_with_id_captures_current_span_trace_context() { + let (session, _turn_context) = make_session_and_context().await; + let (tx_sub, rx_sub) = async_channel::bounded(1); + let (_tx_event, rx_event) = async_channel::unbounded(); + let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); + let codex = Codex { + tx_sub, + rx_event, + agent_status, + session: Arc::new(session), + }; + + init_test_tracing(); + + let request_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), + tracestate: Some("vendor=value".into()), + }; + let request_span = info_span!("app_server.request"); + assert!(set_parent_from_w3c_trace_context( + &request_span, + &request_parent + )); + + let expected_trace = async { + let expected_trace = + current_span_w3c_trace_context().expect("current span should have trace context"); + codex + .submit_with_id(Submission { + id: "sub-1".into(), + op: Op::Interrupt, + trace: None, + }) + .await + .expect("submit should succeed"); + expected_trace + } + .instrument(request_span) + .await; + + let submitted = rx_sub.recv().await.expect("submission"); + assert_eq!(submitted.trace, Some(expected_trace)); + } + + #[tokio::test] + async fn new_default_turn_captures_current_span_trace_id() { + let (session, _turn_context) = make_session_and_context().await; + + init_test_tracing(); + + let request_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), + tracestate: Some("vendor=value".into()), + }; + let request_span = info_span!("app_server.request"); + assert!(set_parent_from_w3c_trace_context( + &request_span, + &request_parent + )); + + let turn_context_item = async { + let expected_trace_id = Span::current() + .context() + .span() + .span_context() + .trace_id() + .to_string(); + let turn_context = session.new_default_turn().await; + let turn_context_item = turn_context.to_turn_context_item(); + assert_eq!(turn_context_item.trace_id, Some(expected_trace_id)); + turn_context_item + } + .instrument(request_span) + .await; + + assert_eq!( + turn_context_item.trace_id.as_deref(), + Some("00000000000000000000000000000011") + ); + } + + #[test] + fn submission_dispatch_span_prefers_submission_trace_context() { + init_test_tracing(); + + let ambient_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000033-0000000000000044-01".into()), + tracestate: None, + }; + let ambient_span = info_span!("ambient"); + assert!(set_parent_from_w3c_trace_context( + &ambient_span, + &ambient_parent + )); + + let submission_trace = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000055-0000000000000066-01".into()), + tracestate: Some("vendor=value".into()), + }; + let dispatch_span = ambient_span.in_scope(|| { + submission_dispatch_span(&Submission { + id: "sub-1".into(), + op: Op::Interrupt, + trace: Some(submission_trace), + }) + }); + + let trace_id = dispatch_span.context().span().span_context().trace_id(); + assert_eq!( + trace_id, + TraceId::from_hex("00000000000000000000000000000055").expect("trace id") + ); + } + + #[test] + fn submission_dispatch_span_uses_debug_for_realtime_audio() { + init_test_tracing(); + + let dispatch_span = submission_dispatch_span(&Submission { + id: "sub-1".into(), + op: Op::RealtimeConversationAudio(ConversationAudioParams { + frame: RealtimeAudioFrame { + data: "ZmFrZQ==".into(), + sample_rate: 16_000, + num_channels: 1, + samples_per_channel: Some(160), + }, + }), + trace: None, + }); + + assert_eq!( + dispatch_span.metadata().expect("span metadata").level(), + &tracing::Level::DEBUG + ); + } + + #[tokio::test] + async fn spawn_task_turn_span_inherits_dispatch_trace_context() { + struct TraceCaptureTask { + captured_trace: Arc>>, + } + + #[async_trait::async_trait] + impl SessionTask for TraceCaptureTask { + fn kind(&self) -> TaskKind { + TaskKind::Regular + } + + fn span_name(&self) -> &'static str { + "session_task.trace_capture" + } + + async fn run( + self: Arc, + _session: Arc, + _ctx: Arc, + _input: Vec, + _cancellation_token: CancellationToken, + ) -> Option { + let mut trace = self + .captured_trace + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *trace = current_span_w3c_trace_context(); + None + } + } + + init_test_tracing(); + + let request_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), + tracestate: Some("vendor=value".into()), + }; + let request_span = tracing::info_span!("app_server.request"); + assert!(set_parent_from_w3c_trace_context( + &request_span, + &request_parent + )); + + let submission_trace = async { + current_span_w3c_trace_context().expect("request span should have trace context") + } + .instrument(request_span) + .await; + + let dispatch_span = submission_dispatch_span(&Submission { + id: "sub-1".into(), + op: Op::Interrupt, + trace: Some(submission_trace.clone()), + }); + let dispatch_span_id = dispatch_span.context().span().span_context().span_id(); + + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let captured_trace = Arc::new(std::sync::Mutex::new(None)); + + async { + sess.spawn_task( + Arc::clone(&tc), + vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }], + TraceCaptureTask { + captured_trace: Arc::clone(&captured_trace), + }, + ) + .await; + } + .instrument(dispatch_span) + .await; + + let evt = tokio::time::timeout(StdDuration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for turn completion") + .expect("event"); + assert!(matches!(evt.msg, EventMsg::TurnComplete(_))); + + let task_trace = captured_trace + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + .expect("turn task should capture the current span trace context"); + let submission_context = + codex_otel::context_from_w3c_trace_context(&submission_trace).expect("submission"); + let task_context = + codex_otel::context_from_w3c_trace_context(&task_trace).expect("task trace"); + + assert_eq!( + task_context.span().span_context().trace_id(), + submission_context.span().span_context().trace_id() + ); + assert_ne!( + task_context.span().span_context().span_id(), + dispatch_span_id + ); + } + + pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( + dynamic_tools: Vec, + ) -> ( + Arc, + Arc, + async_channel::Receiver, + ) { + let (tx_event, rx_event) = async_channel::unbounded(); + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let conversation_id = ThreadId::default(); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let models_manager = Arc::new(ModelsManager::new( + config.codex_home.clone(), + auth_manager.clone(), + None, + CollaborationModesConfig::default(), + )); + let agent_control = AgentControl::default(); + let exec_policy = ExecPolicyManager::default(); + let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = + ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools, + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + let per_turn_config = Session::build_per_turn_config(&session_configuration); + let model_info = ModelsManager::construct_model_info_offline_for_tests( + session_configuration.collaboration_mode.model(), + &per_turn_config, + ); + let otel_manager = otel_manager( + conversation_id, + config.as_ref(), + &model_info, + session_configuration.session_source.clone(), + ); + + let state = SessionState::new(session_configuration.clone()); + let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); + let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); + let skills_manager = Arc::new(SkillsManager::new( + config.codex_home.clone(), + Arc::clone(&plugins_manager), + )); + let network_approval = Arc::new(NetworkApprovalService::default()); + + let file_watcher = Arc::new(FileWatcher::noop()); + let services = SessionServices { + mcp_connection_manager: Arc::new(RwLock::new( + McpConnectionManager::new_mcp_connection_manager_for_tests( + &config.permissions.approval_policy, + ), + )), + mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), + unified_exec_manager: UnifiedExecProcessManager::new( + config.background_terminal_max_timeout, + ), + shell_zsh_path: None, + main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), + analytics_events_client: AnalyticsEventsClient::new( + Arc::clone(&config), + Arc::clone(&auth_manager), + ), + hooks: Hooks::new(HooksConfig { + legacy_notify_argv: config.notify.clone(), + }), + rollout: Mutex::new(None), + user_shell: Arc::new(default_user_shell()), + shell_snapshot_tx: watch::channel(None).0, + show_raw_agent_reasoning: config.show_raw_agent_reasoning, + exec_policy, + auth_manager: Arc::clone(&auth_manager), + otel_manager: otel_manager.clone(), + models_manager: Arc::clone(&models_manager), + tool_approvals: Mutex::new(ApprovalStore::default()), + execve_session_approvals: RwLock::new(HashMap::new()), + skills_manager, + plugins_manager, + mcp_manager, + file_watcher, + agent_control, + network_proxy: None, + network_approval: Arc::clone(&network_approval), + state_db: None, + model_client: ModelClient::new( + Some(Arc::clone(&auth_manager)), + conversation_id, + session_configuration.provider.clone(), + session_configuration.session_source.clone(), + config.model_verbosity, + ws_version_from_features(config.as_ref()), + config.features.enabled(Feature::EnableRequestCompression), + config.features.enabled(Feature::RuntimeMetrics), + Session::build_model_client_beta_features_header(config.as_ref()), + ), + }; + let js_repl = Arc::new(JsReplHandle::with_node_path( + config.js_repl_node_path.clone(), + config.js_repl_node_module_dirs.clone(), + )); + + let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); + let turn_context = Arc::new(Session::make_turn_context( + Some(Arc::clone(&auth_manager)), + &otel_manager, + session_configuration.provider.clone(), + &session_configuration, + per_turn_config, + model_info, + None, + "turn_id".to_string(), + Arc::clone(&js_repl), + skills_outcome, + )); + + let session = Arc::new(Session { + conversation_id, + tx_event, + agent_status: agent_status_tx, + state: Mutex::new(state), + features: config.features.clone(), + pending_mcp_server_refresh_config: Mutex::new(None), + conversation: Arc::new(RealtimeConversationManager::new()), + active_turn: Mutex::new(None), + services, + js_repl, + next_internal_sub_id: AtomicU64::new(0), + }); + + (session, turn_context, rx_event) + } + + // Like make_session_and_context, but returns Arc and the event receiver + // so tests can assert on emitted events. + pub(crate) async fn make_session_and_context_with_rx() -> ( + Arc, + Arc, + async_channel::Receiver, + ) { + make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await + } + + #[tokio::test] + async fn refresh_mcp_servers_is_deferred_until_next_turn() { + let (session, turn_context) = make_session_and_context().await; + let old_token = session.mcp_startup_cancellation_token().await; + assert!(!old_token.is_cancelled()); + + let mcp_oauth_credentials_store_mode = + serde_json::to_value(OAuthCredentialsStoreMode::Auto).expect("serialize store mode"); + let refresh_config = McpServerRefreshConfig { + mcp_servers: json!({}), + mcp_oauth_credentials_store_mode, + }; + { + let mut guard = session.pending_mcp_server_refresh_config.lock().await; + *guard = Some(refresh_config); + } + + assert!(!old_token.is_cancelled()); + assert!( + session + .pending_mcp_server_refresh_config + .lock() + .await + .is_some() + ); + + session + .refresh_mcp_servers_if_requested(&turn_context) + .await; + + assert!(old_token.is_cancelled()); + assert!( + session + .pending_mcp_server_refresh_config + .lock() + .await + .is_none() + ); + let new_token = session.mcp_startup_cancellation_token().await; + assert!(!new_token.is_cancelled()); + } + + #[tokio::test] + async fn record_model_warning_appends_user_message() { + let (mut session, turn_context) = make_session_and_context().await; + let features = crate::features::Features::with_defaults().into(); + session.features = features; + + session + .record_model_warning("too many unified exec processes", &turn_context) + .await; + + let history = session.clone_history().await; + let history_items = history.raw_items(); + let last = history_items.last().expect("warning recorded"); + + match last { + ResponseItem::Message { role, content, .. } => { + assert_eq!(role, "user"); + assert_eq!( + content, + &vec![ContentItem::InputText { + text: "Warning: too many unified exec processes".to_string(), + }] + ); + } + other => panic!("expected user message, got {other:?}"), + } + } + + #[tokio::test] + async fn spawn_task_does_not_update_previous_turn_settings_for_non_run_turn_tasks() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + sess.set_previous_turn_settings(None).await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + assert_eq!(sess.previous_turn_settings().await, None); + } + + #[tokio::test] + async fn build_settings_update_items_emits_environment_item_for_network_changes() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + + let mut config = (*current_context.config).clone(); + let mut requirements = config.config_layer_stack.requirements().clone(); + requirements.network = Some(Sourced::new( + NetworkConstraints { + allowed_domains: Some(vec!["api.example.com".to_string()]), + denied_domains: Some(vec!["blocked.example.com".to_string()]), + ..Default::default() + }, + RequirementSource::CloudRequirements, + )); + let layers = config + .config_layer_stack + .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) + .into_iter() + .cloned() + .collect(); + config.config_layer_stack = ConfigLayerStack::new( + layers, + requirements, + config.config_layer_stack.requirements_toml().clone(), + ) + .expect("rebuild config layer stack with network requirements"); + current_context.config = Arc::new(config); + + let reference_context_item = previous_context.to_turn_context_item(); + let update_items = session + .build_settings_update_items(Some(&reference_context_item), ¤t_context) + .await; + + let environment_update = update_items + .iter() + .find_map(|item| match item { + ResponseItem::Message { role, content, .. } if role == "user" => { + let [ContentItem::InputText { text }] = content.as_slice() else { + return None; + }; + text.contains("").then_some(text) + } + _ => None, + }) + .expect("environment update item should be emitted"); + assert!(environment_update.contains("")); + assert!(environment_update.contains("api.example.com")); + assert!(environment_update.contains("blocked.example.com")); + } + + #[tokio::test] + async fn build_settings_update_items_emits_environment_item_for_time_changes() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.current_date = Some("2026-02-27".to_string()); + current_context.timezone = Some("Europe/Berlin".to_string()); + + let reference_context_item = previous_context.to_turn_context_item(); + let update_items = session + .build_settings_update_items(Some(&reference_context_item), ¤t_context) + .await; + + let environment_update = update_items + .iter() + .find_map(|item| match item { + ResponseItem::Message { role, content, .. } if role == "user" => { + let [ContentItem::InputText { text }] = content.as_slice() else { + return None; + }; + text.contains("").then_some(text) + } + _ => None, + }) + .expect("environment update item should be emitted"); + assert!(environment_update.contains("2026-02-27")); + assert!(environment_update.contains("Europe/Berlin")); + } + + #[tokio::test] + async fn build_settings_update_items_emits_realtime_start_when_session_becomes_live() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.realtime_active = true; + + let update_items = session + .build_settings_update_items( + Some(&previous_context.to_turn_context_item()), + ¤t_context, + ) + .await; + + let developer_texts = developer_input_texts(&update_items); + assert!( + developer_texts + .iter() + .any(|text| text.contains("")), + "expected a realtime start update, got {developer_texts:?}" + ); + } + + #[tokio::test] + async fn build_settings_update_items_emits_realtime_end_when_session_stops_being_live() { + let (session, mut previous_context) = make_session_and_context().await; + previous_context.realtime_active = true; + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.realtime_active = false; + + let update_items = session + .build_settings_update_items( + Some(&previous_context.to_turn_context_item()), + ¤t_context, + ) + .await; + + let developer_texts = developer_input_texts(&update_items); + assert!( + developer_texts + .iter() + .any(|text| text.contains("Reason: inactive")), + "expected a realtime end update, got {developer_texts:?}" + ); + } + + #[tokio::test] + async fn build_settings_update_items_uses_previous_turn_settings_for_realtime_end() { + let (session, previous_context) = make_session_and_context().await; + let mut previous_context_item = previous_context.to_turn_context_item(); + previous_context_item.realtime_active = None; + let previous_turn_settings = PreviousTurnSettings { + model: previous_context.model_info.slug.clone(), + realtime_active: Some(true), + }; + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.realtime_active = false; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let update_items = session + .build_settings_update_items(Some(&previous_context_item), ¤t_context) + .await; + + let developer_texts = developer_input_texts(&update_items); + assert!( + developer_texts + .iter() + .any(|text| text.contains("Reason: inactive")), + "expected a realtime end update from previous turn settings, got {developer_texts:?}" + ); + } + + #[tokio::test] + async fn build_initial_context_uses_previous_realtime_state() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.realtime_active = true; + + let initial_context = session.build_initial_context(&turn_context).await; + let developer_texts = developer_input_texts(&initial_context); + assert!( + developer_texts + .iter() + .any(|text| text.contains("")), + "expected initial context to describe active realtime state, got {developer_texts:?}" + ); + + let previous_context_item = turn_context.to_turn_context_item(); + { + let mut state = session.state.lock().await; + state.set_reference_context_item(Some(previous_context_item)); + } + let resumed_context = session.build_initial_context(&turn_context).await; + let resumed_developer_texts = developer_input_texts(&resumed_context); + assert!( + !resumed_developer_texts + .iter() + .any(|text| text.contains("")), + "did not expect a duplicate realtime update, got {resumed_developer_texts:?}" + ); + } + + #[tokio::test] + async fn build_initial_context_uses_previous_turn_settings_for_realtime_end() { + let (session, turn_context) = make_session_and_context().await; + let previous_turn_settings = PreviousTurnSettings { + model: turn_context.model_info.slug.clone(), + realtime_active: Some(true), + }; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + let developer_texts = developer_input_texts(&initial_context); + assert!( + developer_texts + .iter() + .any(|text| text.contains("Reason: inactive")), + "expected initial context to describe an ended realtime session, got {developer_texts:?}" + ); + } + + #[tokio::test] + async fn build_initial_context_restates_realtime_start_when_reference_context_is_missing() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.realtime_active = true; + let previous_turn_settings = PreviousTurnSettings { + model: turn_context.model_info.slug.clone(), + realtime_active: Some(true), + }; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + let developer_texts = developer_input_texts(&initial_context); + assert!( + developer_texts + .iter() + .any(|text| text.contains("")), + "expected initial context to restate active realtime when the reference context is missing, got {developer_texts:?}" + ); + } + + #[tokio::test] + async fn record_context_updates_and_set_reference_context_item_injects_full_context_when_baseline_missing() + { + let (session, turn_context) = make_session_and_context().await; + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + let history = session.clone_history().await; + let initial_context = session.build_initial_context(&turn_context).await; + assert_eq!(history.raw_items().to_vec(), initial_context); + + let current_context = session.reference_context_item().await; + assert_eq!( + serde_json::to_value(current_context).expect("serialize current context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected context item") + ); + } + + #[tokio::test] + async fn record_context_updates_and_set_reference_context_item_reinjects_full_context_after_clear() + { + let (session, turn_context) = make_session_and_context().await; + let compacted_summary = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{}\nsummary", crate::compact::SUMMARY_PREFIX), + }], + end_turn: None, + phase: None, + }; + session + .record_into_history(std::slice::from_ref(&compacted_summary), &turn_context) + .await; + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + { + let mut state = session.state.lock().await; + state.set_reference_context_item(None); + } + session + .replace_history(vec![compacted_summary.clone()], None) + .await; + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + + let history = session.clone_history().await; + let mut expected_history = vec![compacted_summary]; + expected_history.extend(session.build_initial_context(&turn_context).await); + assert_eq!(history.raw_items().to_vec(), expected_history); + } + + #[tokio::test] + async fn record_context_updates_and_set_reference_context_item_persists_baseline_without_emitting_diffs() + { + let (session, previous_context) = make_session_and_context().await; + let next_model = if previous_context.model_info.slug == "gpt-5.1" { + "gpt-5" + } else { + "gpt-5.1" + }; + let turn_context = previous_context + .with_model(next_model.to_string(), &session.services.models_manager) + .await; + let previous_context_item = previous_context.to_turn_context_item(); + { + let mut state = session.state.lock().await; + state.set_reference_context_item(Some(previous_context_item.clone())); + } + let config = session.get_config().await; + let recorder = RolloutRecorder::new( + config.as_ref(), + RolloutRecorderParams::new( + ThreadId::default(), + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await + .expect("create rollout recorder"); + let rollout_path = recorder.rollout_path().to_path_buf(); + { + let mut rollout = session.services.rollout.lock().await; + *rollout = Some(recorder); + } + + let update_items = session + .build_settings_update_items(Some(&previous_context_item), &turn_context) + .await; + assert_eq!(update_items, Vec::new()); + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + + assert_eq!( + session.clone_history().await.raw_items().to_vec(), + Vec::new() + ); + assert_eq!( + serde_json::to_value(session.reference_context_item().await) + .expect("serialize current context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected context item") + ); + session.ensure_rollout_materialized().await; + session.flush_rollout().await; + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + let persisted_turn_context = resumed.history.iter().find_map(|item| match item { + RolloutItem::TurnContext(ctx) => Some(ctx.clone()), + _ => None, + }); + assert_eq!( + serde_json::to_value(persisted_turn_context) + .expect("serialize persisted turn context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected turn context item") + ); + } + + #[tokio::test] + async fn build_initial_context_prepends_model_switch_message() { + let (session, turn_context) = make_session_and_context().await; + let previous_turn_settings = PreviousTurnSettings { + model: "previous-regular-model".to_string(), + realtime_active: None, + }; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + + let ResponseItem::Message { role, content, .. } = &initial_context[0] else { + panic!("expected developer message"); + }; + assert_eq!(role, "developer"); + let [ContentItem::InputText { text }, ..] = content.as_slice() else { + panic!("expected developer text"); + }; + assert!(text.contains("")); + } + + #[tokio::test] + async fn record_context_updates_and_set_reference_context_item_persists_full_reinjection_to_rollout() + { + let (session, previous_context) = make_session_and_context().await; + let next_model = if previous_context.model_info.slug == "gpt-5.1" { + "gpt-5" + } else { + "gpt-5.1" + }; + let turn_context = previous_context + .with_model(next_model.to_string(), &session.services.models_manager) + .await; + let config = session.get_config().await; + let recorder = RolloutRecorder::new( + config.as_ref(), + RolloutRecorderParams::new( + ThreadId::default(), + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await + .expect("create rollout recorder"); + let rollout_path = recorder.rollout_path().to_path_buf(); + { + let mut rollout = session.services.rollout.lock().await; + *rollout = Some(recorder); + } + + session + .persist_rollout_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( + UserMessageEvent { + message: "seed rollout".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + ))]) + .await; + { + let mut state = session.state.lock().await; + state.set_reference_context_item(None); + } + + session + .set_previous_turn_settings(Some(PreviousTurnSettings { + model: previous_context.model_info.slug.clone(), + realtime_active: Some(previous_context.realtime_active), + })) + .await; + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + session.ensure_rollout_materialized().await; + session.flush_rollout().await; + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + let persisted_turn_context = resumed.history.iter().find_map(|item| match item { + RolloutItem::TurnContext(ctx) => Some(ctx.clone()), + _ => None, + }); + + assert_eq!( + serde_json::to_value(persisted_turn_context) + .expect("serialize persisted turn context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected turn context item") + ); + } + + #[tokio::test] + async fn run_user_shell_command_does_not_set_reference_context_item() { + let (session, _turn_context, rx) = make_session_and_context_with_rx().await; + { + let mut state = session.state.lock().await; + state.set_reference_context_item(None); + } + + handlers::run_user_shell_command(&session, "sub-id".to_string(), "echo shell".to_string()) + .await; + + let deadline = StdDuration::from_secs(15); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + if matches!(evt.msg, EventMsg::TurnComplete(_)) { + break; + } + } + + assert!( + session.reference_context_item().await.is_none(), + "standalone shell tasks should not mutate previous context" + ); + } + + #[derive(Clone, Copy)] + struct NeverEndingTask { + kind: TaskKind, + listen_to_cancellation_token: bool, + } + + #[async_trait::async_trait] + impl SessionTask for NeverEndingTask { + fn kind(&self) -> TaskKind { + self.kind + } + + fn span_name(&self) -> &'static str { + "session_task.never_ending" + } + + async fn run( + self: Arc, + _session: Arc, + _ctx: Arc, + _input: Vec, + cancellation_token: CancellationToken, + ) -> Option { + if self.listen_to_cancellation_token { + cancellation_token.cancelled().await; + return None; + } + loop { + sleep(Duration::from_secs(60)).await; + } + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[test_log::test] + async fn abort_regular_task_emits_turn_aborted_only() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + // Interrupts persist a model-visible `` marker into history, but there is no + // separate client-visible event for that marker (only `EventMsg::TurnAborted`). + let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + // No extra events should be emitted after an abort. + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn abort_gracefully_emits_turn_aborted_only() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + // Even if tasks handle cancellation gracefully, interrupts still result in `TurnAborted` + // being the only client-visible signal. + let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + // No extra events should be emitted after an abort. + assert!(rx.try_recv().is_err()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + while rx.try_recv().is_ok() {} + + sess.inject_response_items(vec![ResponseInputItem::Message { + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late pending input".to_string(), + }], + }]) + .await + .expect("inject pending input into active turn"); + + sess.on_task_finished(Arc::clone(&tc), None).await; + + let history = sess.clone_history().await; + let expected = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late pending input".to_string(), + }], + end_turn: None, + phase: None, + }; + assert!( + history.raw_items().iter().any(|item| item == &expected), + "expected pending input to be persisted into history on turn completion" + ); + + let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected raw response item event") + .expect("channel open"); + assert!(matches!(first.msg, EventMsg::RawResponseItem(_))); + + let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected item started event") + .expect("channel open"); + assert!(matches!( + second.msg, + EventMsg::ItemStarted(ItemStartedEvent { + item: TurnItem::UserMessage(UserMessageItem { content, .. }), + .. + }) if content == vec![UserInput::Text { + text: "late pending input".to_string(), + text_elements: Vec::new(), + }] + )); + + let third = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected item completed event") + .expect("channel open"); + assert!(matches!( + third.msg, + EventMsg::ItemCompleted(ItemCompletedEvent { + item: TurnItem::UserMessage(UserMessageItem { content, .. }), + .. + }) if content == vec![UserInput::Text { + text: "late pending input".to_string(), + text_elements: Vec::new(), + }] + )); + + let fourth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected legacy user message event") + .expect("channel open"); + assert!(matches!( + fourth.msg, + EventMsg::UserMessage(UserMessageEvent { + message, + images, + text_elements, + local_images, + }) if message == "late pending input" + && images == Some(Vec::new()) + && text_elements.is_empty() + && local_images.is_empty() + )); + + let fifth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected turn complete event") + .expect("channel open"); + assert!(matches!( + fifth.msg, + EventMsg::TurnComplete(TurnCompleteEvent { + turn_id, + last_agent_message: None, + }) if turn_id == tc.sub_id + )); + } + + #[tokio::test] + async fn steer_input_requires_active_turn() { + let (sess, _tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "steer".to_string(), + text_elements: Vec::new(), + }]; + + let err = sess + .steer_input(input, None) + .await + .expect_err("steering without active turn should fail"); + + assert!(matches!(err, SteerInputError::NoActiveTurn(_))); + } + + #[tokio::test] + async fn steer_input_enforces_expected_turn_id() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + let steer_input = vec![UserInput::Text { + text: "steer".to_string(), + text_elements: Vec::new(), + }]; + let err = sess + .steer_input(steer_input, Some("different-turn-id")) + .await + .expect_err("mismatched expected turn id should fail"); + + match err { + SteerInputError::ExpectedTurnMismatch { expected, actual } => { + assert_eq!( + (expected, actual), + ("different-turn-id".to_string(), tc.sub_id.clone()) + ); + } + other => panic!("unexpected error: {other:?}"), + } + } + + #[tokio::test] + async fn steer_input_returns_active_turn_id() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + let steer_input = vec![UserInput::Text { + text: "steer".to_string(), + text_elements: Vec::new(), + }]; + let turn_id = sess + .steer_input(steer_input, Some(&tc.sub_id)) + .await + .expect("steering with matching expected turn id should succeed"); + + assert_eq!(turn_id, tc.sub_id); + assert!(sess.has_pending_input().await); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn abort_review_task_emits_exited_then_aborted_and_records_history() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "start review".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task(Arc::clone(&tc), input, ReviewTask::new()) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + // Aborting a review task should exit review mode before surfacing the abort to the client. + // We scan for these events (rather than relying on fixed ordering) since unrelated events + // may interleave. + let mut exited_review_mode_idx = None; + let mut turn_aborted_idx = None; + let mut idx = 0usize; + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(3); + while tokio::time::Instant::now() < deadline { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + let event_idx = idx; + idx = idx.saturating_add(1); + match evt.msg { + EventMsg::ExitedReviewMode(ev) => { + assert!(ev.review_output.is_none()); + exited_review_mode_idx = Some(event_idx); + } + EventMsg::TurnAborted(ev) => { + assert_eq!(TurnAbortReason::Interrupted, ev.reason); + turn_aborted_idx = Some(event_idx); + break; + } + _ => {} + } + } + assert!( + exited_review_mode_idx.is_some(), + "expected ExitedReviewMode after abort" + ); + assert!( + turn_aborted_idx.is_some(), + "expected TurnAborted after abort" + ); + assert!( + exited_review_mode_idx.unwrap() < turn_aborted_idx.unwrap(), + "expected ExitedReviewMode before TurnAborted" + ); + + let history = sess.clone_history().await; + // The `` marker is silent in the event stream, so verify it is still + // recorded in history for the model. + assert!( + history.raw_items().iter().any(|item| { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + if role != "user" { + return false; + } + content.iter().any(|content_item| { + let ContentItem::InputText { text } = content_item else { + return false; + }; + text.contains(crate::contextual_user_message::TURN_ABORTED_OPEN_TAG) + }) + }), + "expected a model-visible turn aborted marker in history after interrupt" + ); + } + + #[tokio::test] + async fn fatal_tool_error_stops_turn_and_reports_error() { + let (session, turn_context, _rx) = make_session_and_context_with_rx().await; + let tools = { + session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await + }; + let app_tools = Some(tools.clone()); + let router = ToolRouter::from_config( + &turn_context.tools_config, + Some( + tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), + ), + app_tools, + turn_context.dynamic_tools.as_slice(), + ); + let item = ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-1".to_string(), + name: "shell".to_string(), + input: "{}".to_string(), + }; + + let call = ToolRouter::build_tool_call(session.as_ref(), item.clone()) + .await + .expect("build tool call") + .expect("tool call present"); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let err = router + .dispatch_tool_call( + Arc::clone(&session), + Arc::clone(&turn_context), + tracker, + call, + ToolCallSource::Direct, + ) + .await + .expect_err("expected fatal error"); + + match err { + FunctionCallError::Fatal(message) => { + assert_eq!(message, "tool shell invoked with incompatible payload"); + } + other => panic!("expected FunctionCallError::Fatal, got {other:?}"), + } + } + + async fn sample_rollout( + session: &Session, + _turn_context: &TurnContext, + ) -> (Vec, Vec) { + let mut rollout_items = Vec::new(); + let mut live_history = ContextManager::new(); + + // Use the same turn_context source as record_initial_history so model_info (and thus + // personality_spec) matches reconstruction. + let reconstruction_turn = session.new_default_turn().await; + let mut initial_context = session + .build_initial_context(reconstruction_turn.as_ref()) + .await; + // Ensure personality_spec is present when Personality is enabled, so expected matches + // what reconstruction produces (build_initial_context may omit it when baked into model). + if !initial_context.iter().any(|m| { + matches!(m, ResponseItem::Message { role, content, .. } + if role == "developer" + && content.iter().any(|c| { + matches!(c, ContentItem::InputText { text } if text.contains("")) + })) + }) + && let Some(p) = reconstruction_turn.personality + && session.features.enabled(Feature::Personality) + && let Some(personality_message) = reconstruction_turn + .model_info + .model_messages + .as_ref() + .and_then(|m| m.get_personality_message(Some(p)).filter(|s| !s.is_empty())) + { + let msg = + DeveloperInstructions::personality_spec_message(personality_message).into(); + let insert_at = initial_context + .iter() + .position(|m| matches!(m, ResponseItem::Message { role, .. } if role == "developer")) + .map(|i| i + 1) + .unwrap_or(0); + initial_context.insert(insert_at, msg); + } + for item in &initial_context { + rollout_items.push(RolloutItem::ResponseItem(item.clone())); + } + live_history.record_items( + initial_context.iter(), + reconstruction_turn.truncation_policy, + ); + + let user1 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "first user".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&user1), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(user1)); + + let assistant1 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply one".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&assistant1), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(assistant1)); + + let summary1 = "summary one"; + let snapshot1 = live_history + .clone() + .for_prompt(&reconstruction_turn.model_info.input_modalities); + let user_messages1 = collect_user_messages(&snapshot1); + let rebuilt1 = compact::build_compacted_history(Vec::new(), &user_messages1, summary1); + live_history.replace(rebuilt1); + rollout_items.push(RolloutItem::Compacted(CompactedItem { + message: summary1.to_string(), + replacement_history: None, + })); + + let user2 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "second user".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&user2), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(user2)); + + let assistant2 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply two".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&assistant2), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(assistant2)); + + let summary2 = "summary two"; + let snapshot2 = live_history + .clone() + .for_prompt(&reconstruction_turn.model_info.input_modalities); + let user_messages2 = collect_user_messages(&snapshot2); + let rebuilt2 = compact::build_compacted_history(Vec::new(), &user_messages2, summary2); + live_history.replace(rebuilt2); + rollout_items.push(RolloutItem::Compacted(CompactedItem { + message: summary2.to_string(), + replacement_history: None, + })); + + let user3 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "third user".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&user3), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(user3)); + + let assistant3 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply three".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&assistant3), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(assistant3)); + + ( + rollout_items, + live_history.for_prompt(&reconstruction_turn.model_info.input_modalities), + ) + } + + #[tokio::test] + async fn rejects_escalated_permissions_when_policy_not_on_request() { + use crate::exec::ExecParams; + use crate::protocol::AskForApproval; + use crate::protocol::SandboxPolicy; + use crate::sandboxing::SandboxPermissions; + use crate::turn_diff_tracker::TurnDiffTracker; + use std::collections::HashMap; + + let (session, mut turn_context_raw) = make_session_and_context().await; + // Ensure policy is NOT OnRequest so the early rejection path triggers + turn_context_raw + .approval_policy + .set(AskForApproval::OnFailure) + .expect("test setup should allow updating approval policy"); + let session = Arc::new(session); + let mut turn_context = Arc::new(turn_context_raw); + + let timeout_ms = 1000; + let sandbox_permissions = SandboxPermissions::RequireEscalated; + let params = ExecParams { + command: if cfg!(windows) { + vec![ + "cmd.exe".to_string(), + "/C".to_string(), + "echo hi".to_string(), + ] + } else { + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "echo hi".to_string(), + ] + }, + cwd: turn_context.cwd.clone(), + expiration: timeout_ms.into(), + env: HashMap::new(), + network: None, + sandbox_permissions, + windows_sandbox_level: turn_context.windows_sandbox_level, + justification: Some("test".to_string()), + arg0: None, + }; + + let params2 = ExecParams { + sandbox_permissions: SandboxPermissions::UseDefault, + command: params.command.clone(), + cwd: params.cwd.clone(), + expiration: timeout_ms.into(), + env: HashMap::new(), + network: None, + windows_sandbox_level: turn_context.windows_sandbox_level, + justification: params.justification.clone(), + arg0: None, + }; + + let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + + let tool_name = "shell"; + let call_id = "test-call".to_string(); + + let handler = ShellHandler; + let resp = handler + .handle(ToolInvocation { + session: Arc::clone(&session), + turn: Arc::clone(&turn_context), + tracker: Arc::clone(&turn_diff_tracker), + call_id, + tool_name: tool_name.to_string(), + payload: ToolPayload::Function { + arguments: serde_json::json!({ + "command": params.command.clone(), + "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), + "timeout_ms": params.expiration.timeout_ms(), + "sandbox_permissions": params.sandbox_permissions, + "justification": params.justification.clone(), + }) + .to_string(), + }, + }) + .await; + + let Err(FunctionCallError::RespondToModel(output)) = resp else { + panic!("expected error result"); + }; + + let expected = format!( + "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", + policy = turn_context.approval_policy.value() + ); + + pretty_assertions::assert_eq!(output, expected); + + // Now retry the same command WITHOUT escalated permissions; should succeed. + // Force DangerFullAccess to avoid platform sandbox dependencies in tests. + Arc::get_mut(&mut turn_context) + .expect("unique turn context Arc") + .sandbox_policy + .set(SandboxPolicy::DangerFullAccess) + .expect("test setup should allow updating sandbox policy"); + + let resp2 = handler + .handle(ToolInvocation { + session: Arc::clone(&session), + turn: Arc::clone(&turn_context), + tracker: Arc::clone(&turn_diff_tracker), + call_id: "test-call-2".to_string(), + tool_name: tool_name.to_string(), + payload: ToolPayload::Function { + arguments: serde_json::json!({ + "command": params2.command.clone(), + "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), + "timeout_ms": params2.expiration.timeout_ms(), + "sandbox_permissions": params2.sandbox_permissions, + "justification": params2.justification.clone(), + }) + .to_string(), + }, + }) + .await; + + let output = match resp2.expect("expected Ok result") { + ToolOutput::Function { + body: FunctionCallOutputBody::Text(content), + .. + } => content, + _ => panic!("unexpected tool output"), + }; + + #[derive(Deserialize, PartialEq, Eq, Debug)] + struct ResponseExecMetadata { + exit_code: i32, + } + + #[derive(Deserialize)] + struct ResponseExecOutput { + output: String, + metadata: ResponseExecMetadata, + } + + let exec_output: ResponseExecOutput = + serde_json::from_str(&output).expect("valid exec output json"); + + pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 }); + assert!(exec_output.output.contains("hi")); + } + #[tokio::test] + async fn unified_exec_rejects_escalated_permissions_when_policy_not_on_request() { + use crate::protocol::AskForApproval; + use crate::sandboxing::SandboxPermissions; + use crate::turn_diff_tracker::TurnDiffTracker; + + let (session, mut turn_context_raw) = make_session_and_context().await; + turn_context_raw + .approval_policy + .set(AskForApproval::OnFailure) + .expect("test setup should allow updating approval policy"); + let session = Arc::new(session); + let turn_context = Arc::new(turn_context_raw); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + + let handler = UnifiedExecHandler; + let resp = handler + .handle(ToolInvocation { + session: Arc::clone(&session), + turn: Arc::clone(&turn_context), + tracker: Arc::clone(&tracker), + call_id: "exec-call".to_string(), + tool_name: "exec_command".to_string(), + payload: ToolPayload::Function { + arguments: serde_json::json!({ + "cmd": "echo hi", + "sandbox_permissions": SandboxPermissions::RequireEscalated, + "justification": "need unsandboxed execution", + }) + .to_string(), + }, + }) + .await; + + let Err(FunctionCallError::RespondToModel(output)) = resp else { + panic!("expected error result"); + }; + + let expected = format!( + "approval policy is {policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {policy:?}", + policy = turn_context.approval_policy.value() + ); + + pretty_assertions::assert_eq!(output, expected); + } diff --git a/codex-rs/core/src/config/managed_features.rs b/codex-rs/core/src/config/managed_features.rs index 13510b0d853..4e45dedf91a 100644 --- a/codex-rs/core/src/config/managed_features.rs +++ b/codex-rs/core/src/config/managed_features.rs @@ -218,14 +218,6 @@ fn explicit_feature_settings_in_config(cfg: &ConfigToml) -> Vec<(String, Feature enabled, )); } - if let Some(enabled) = cfg.tools.as_ref().and_then(|tools| tools.web_search) { - explicit_settings.push(( - "tools.web_search".to_string(), - Feature::WebSearchRequest, - enabled, - )); - } - for (profile_name, profile) in &cfg.profiles { if let Some(features) = profile.features.as_ref() { for (key, enabled) in &features.entries { @@ -259,13 +251,6 @@ fn explicit_feature_settings_in_config(cfg: &ConfigToml) -> Vec<(String, Feature enabled, )); } - if let Some(enabled) = profile.tools_web_search { - explicit_settings.push(( - format!("profiles.{profile_name}.tools_web_search"), - Feature::WebSearchRequest, - enabled, - )); - } } explicit_settings diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 6ea0cc1dc64..9294a9bd3d7 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -68,9 +68,9 @@ use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::TrustLevel; use codex_protocol::config_types::Verbosity; use codex_protocol::config_types::WebSearchConfig; -use codex_protocol::config_types::WebSearchFilters; +use codex_protocol::config_types::WebSearchLocation; use codex_protocol::config_types::WebSearchMode; -use codex_protocol::config_types::WebSearchUserLocation; +use codex_protocol::config_types::WebSearchToolConfig; use codex_protocol::config_types::WindowsSandboxLevel; use codex_protocol::models::MacOsSeatbeltProfileExtensions; use codex_protocol::openai_models::ModelsResponse; @@ -1222,9 +1222,6 @@ pub struct ConfigToml { /// Controls the web search tool mode: disabled, cached, or live. pub web_search: Option, - /// Optional structured configuration for the web search tool. - pub web_search_config: Option, - /// Nested tools section for feature toggles pub tools: Option, @@ -1363,8 +1360,8 @@ pub struct RealtimeAudioToml { #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, JsonSchema)] #[schemars(deny_unknown_fields)] pub struct ToolsToml { - #[serde(default, alias = "web_search_request")] - pub web_search: Option, + #[serde(default)] + pub web_search: Option, /// Enable the `view_image` tool that lets the agent attach local images. #[serde(default)] @@ -1650,53 +1647,52 @@ fn resolve_web_search_config( config_toml: &ConfigToml, config_profile: &ConfigProfile, ) -> Option { - let base = config_toml.web_search_config.as_ref(); - let profile = config_profile.web_search_config.as_ref(); + let base = config_toml + .tools + .as_ref() + .and_then(|tools| tools.web_search.as_ref()); + let profile = config_profile + .tools + .as_ref() + .and_then(|tools| tools.web_search.as_ref()); match (base, profile) { (None, None) => None, - (Some(base), None) => Some(base.clone()), - (None, Some(profile)) => Some(profile.clone()), - (Some(base), Some(profile)) => Some(WebSearchConfig { - filters: match (base.filters.as_ref(), profile.filters.as_ref()) { - (None, None) => None, - (Some(base_filters), None) => Some(base_filters.clone()), - (None, Some(profile_filters)) => Some(profile_filters.clone()), - (Some(base_filters), Some(profile_filters)) => Some(WebSearchFilters { - allowed_domains: profile_filters - .allowed_domains - .clone() - .or_else(|| base_filters.allowed_domains.clone()), - }), - }, - user_location: match (base.user_location.as_ref(), profile.user_location.as_ref()) { - (None, None) => None, - (Some(base_user_location), None) => Some(base_user_location.clone()), - (None, Some(profile_user_location)) => Some(profile_user_location.clone()), - (Some(base_user_location), Some(profile_user_location)) => { - Some(WebSearchUserLocation { - r#type: profile_user_location.r#type, - country: profile_user_location + (Some(base), None) => Some(base.clone().into()), + (None, Some(profile)) => Some(profile.clone().into()), + (Some(base), Some(profile)) => Some( + WebSearchToolConfig { + context_size: profile.context_size.or(base.context_size), + allowed_domains: profile + .allowed_domains + .clone() + .or_else(|| base.allowed_domains.clone()), + location: match (base.location.as_ref(), profile.location.as_ref()) { + (None, None) => None, + (Some(base_location), None) => Some(base_location.clone()), + (None, Some(profile_location)) => Some(profile_location.clone()), + (Some(base_location), Some(profile_location)) => Some(WebSearchLocation { + country: profile_location .country .clone() - .or_else(|| base_user_location.country.clone()), - region: profile_user_location + .or_else(|| base_location.country.clone()), + region: profile_location .region .clone() - .or_else(|| base_user_location.region.clone()), - city: profile_user_location + .or_else(|| base_location.region.clone()), + city: profile_location .city .clone() - .or_else(|| base_user_location.city.clone()), - timezone: profile_user_location + .or_else(|| base_location.city.clone()), + timezone: profile_location .timezone .clone() - .or_else(|| base_user_location.timezone.clone()), - }) - } - }, - search_context_size: profile.search_context_size.or(base.search_context_size), - }), + .or_else(|| base_location.timezone.clone()), + }), + }, + } + .into(), + ), } } @@ -1900,10 +1896,6 @@ impl Config { let web_search_mode = resolve_web_search_mode(&cfg, &config_profile, &features) .unwrap_or(WebSearchMode::Cached); let web_search_config = resolve_web_search_config(&cfg, &config_profile); - // TODO(dylan): We should be able to leverage ConfigLayerStack so that - // we can reliably check this at every config level. - let did_user_set_custom_approval_policy_or_sandbox_mode = - approval_policy_was_explicit || sandbox_mode_was_explicit; let mut model_providers = built_in_model_providers(); // Merge user-defined providers into the built-in list. diff --git a/codex-rs/core/src/config/profile.rs b/codex-rs/core/src/config/profile.rs index f2867b25391..3ee213513d5 100644 --- a/codex-rs/core/src/config/profile.rs +++ b/codex-rs/core/src/config/profile.rs @@ -3,6 +3,7 @@ use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; +use crate::config::ToolsToml; use crate::config::types::Personality; use crate::config::types::WindowsToml; use crate::protocol::AskForApproval; @@ -10,7 +11,6 @@ use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::Verbosity; -use codex_protocol::config_types::WebSearchConfig; use codex_protocol::config_types::WebSearchMode; use codex_protocol::openai_models::ReasoningEffort; @@ -48,10 +48,9 @@ pub struct ConfigProfile { pub include_apply_patch_tool: Option, pub experimental_use_unified_exec_tool: Option, pub experimental_use_freeform_apply_patch: Option, - pub tools_web_search: Option, pub tools_view_image: Option, + pub tools: Option, pub web_search: Option, - pub web_search_config: Option, pub analytics: Option, #[serde(default)] pub windows: Option, @@ -72,6 +71,7 @@ impl From for codex_app_server_protocol::Profile { model_reasoning_effort: config_profile.model_reasoning_effort, model_reasoning_summary: config_profile.model_reasoning_summary, model_verbosity: config_profile.model_verbosity, + tools: config_profile.tools.map(Into::into), chatgpt_base_url: config_profile.chatgpt_base_url, } } diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 7defc571ffb..29e6de281e5 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -213,10 +213,17 @@ impl FeatureOverrides { fn apply(self, features: &mut Features) { LegacyFeatureToggles { include_apply_patch_tool: self.include_apply_patch_tool, - tools_web_search: self.web_search_request, ..Default::default() } .apply(features); + if let Some(enabled) = self.web_search_request { + if enabled { + features.enable(Feature::WebSearchRequest); + } else { + features.disable(Feature::WebSearchRequest); + } + features.record_legacy_usage("web_search_request", Feature::WebSearchRequest); + } } } @@ -342,7 +349,6 @@ impl Features { let base_legacy = LegacyFeatureToggles { experimental_use_freeform_apply_patch: cfg.experimental_use_freeform_apply_patch, experimental_use_unified_exec_tool: cfg.experimental_use_unified_exec_tool, - tools_web_search: cfg.tools.as_ref().and_then(|t| t.web_search), ..Default::default() }; base_legacy.apply(&mut features); @@ -357,7 +363,6 @@ impl Features { .experimental_use_freeform_apply_patch, experimental_use_unified_exec_tool: config_profile.experimental_use_unified_exec_tool, - tools_web_search: config_profile.tools_web_search, }; profile_legacy.apply(&mut features); if let Some(profile_features) = config_profile.features.as_ref() { @@ -388,7 +393,6 @@ fn legacy_usage_notice(alias: &str, feature: Feature) -> (String, Option Feature::WebSearchRequest | Feature::WebSearchCached => { let label = match alias { "web_search" => "[features].web_search", - "tools.web_search" => "[tools].web_search", "features.web_search_request" | "web_search_request" => { "[features].web_search_request" } diff --git a/codex-rs/core/src/features/legacy.rs b/codex-rs/core/src/features/legacy.rs index 45b3dfd5dca..b7aa30482a1 100644 --- a/codex-rs/core/src/features/legacy.rs +++ b/codex-rs/core/src/features/legacy.rs @@ -62,7 +62,6 @@ pub struct LegacyFeatureToggles { pub include_apply_patch_tool: Option, pub experimental_use_freeform_apply_patch: Option, pub experimental_use_unified_exec_tool: Option, - pub tools_web_search: Option, } impl LegacyFeatureToggles { @@ -85,12 +84,6 @@ impl LegacyFeatureToggles { self.experimental_use_unified_exec_tool, "experimental_use_unified_exec_tool", ); - set_if_some( - features, - Feature::WebSearchRequest, - self.tools_web_search, - "tools.web_search", - ); } } diff --git a/codex-rs/core/tests/suite/web_search.rs b/codex-rs/core/tests/suite/web_search.rs index 6df0338e920..c90ca91235a 100644 --- a/codex-rs/core/tests/suite/web_search.rs +++ b/codex-rs/core/tests/suite/web_search.rs @@ -9,6 +9,8 @@ use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use pretty_assertions::assert_eq; use serde_json::Value; +use serde_json::json; +use std::sync::Arc; #[allow(clippy::expect_used)] fn find_web_search_tool(body: &Value) -> &Value { @@ -223,3 +225,61 @@ async fn web_search_mode_updates_between_turns_with_sandbox_policy() { "danger-full-access policy should default web_search to live" ); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn web_search_tool_config_from_config_toml_is_forwarded_to_request() { + skip_if_no_network!(); + + let server = start_mock_server().await; + let sse = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]); + let resp_mock = responses::mount_sse_once(&server, sse).await; + + let home = Arc::new(tempfile::TempDir::new().expect("create codex home")); + std::fs::write( + home.path().join("config.toml"), + r#"web_search = "live" + +[tools.web_search] +context_size = "high" +allowed_domains = ["example.com"] +location = { country = "US", city = "New York", timezone = "America/New_York" } +"#, + ) + .expect("write config.toml"); + + let mut builder = test_codex().with_model("gpt-5-codex").with_home(home); + let test = builder + .build(&server) + .await + .expect("create test Codex conversation"); + + test.submit_turn_with_policy( + "hello configured web search", + SandboxPolicy::DangerFullAccess, + ) + .await + .expect("submit turn"); + + let body = resp_mock.single_request().body_json(); + let tool = find_web_search_tool(&body); + assert_eq!( + tool, + &json!({ + "type": "web_search", + "external_web_access": true, + "search_context_size": "high", + "filters": { + "allowed_domains": ["example.com"], + }, + "user_location": { + "type": "approximate", + "country": "US", + "city": "New York", + "timezone": "America/New_York", + }, + }) + ); +} diff --git a/codex-rs/protocol/src/config_types.rs b/codex-rs/protocol/src/config_types.rs index b4a964bbde7..cb4f934d5e2 100644 --- a/codex-rs/protocol/src/config_types.rs +++ b/codex-rs/protocol/src/config_types.rs @@ -122,6 +122,23 @@ pub enum WebSearchContextSize { High, } +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] +#[schemars(deny_unknown_fields)] +pub struct WebSearchLocation { + pub country: Option, + pub region: Option, + pub city: Option, + pub timezone: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] +#[schemars(deny_unknown_fields)] +pub struct WebSearchToolConfig { + pub context_size: Option, + pub allowed_domains: Option>, + pub location: Option, +} + #[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] #[schemars(deny_unknown_fields)] pub struct WebSearchFilters { @@ -157,6 +174,32 @@ pub struct WebSearchConfig { pub search_context_size: Option, } +impl From for WebSearchUserLocation { + fn from(location: WebSearchLocation) -> Self { + Self { + r#type: WebSearchUserLocationType::Approximate, + country: location.country, + region: location.region, + city: location.city, + timezone: location.timezone, + } + } +} + +impl From for WebSearchConfig { + fn from(config: WebSearchToolConfig) -> Self { + Self { + filters: config + .allowed_domains + .map(|allowed_domains| WebSearchFilters { + allowed_domains: Some(allowed_domains), + }), + user_location: config.location.map(Into::into), + search_context_size: config.context_size, + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Display, JsonSchema, TS)] #[serde(rename_all = "lowercase")] #[strum(serialize_all = "lowercase")] From a092ceadf4c106fe64970418aacedb31c768690d Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 6 Mar 2026 14:12:11 -0500 Subject: [PATCH 3/7] Fix config_read_includes_tools test --- codex-rs/app-server/tests/suite/v2/config_rpc.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codex-rs/app-server/tests/suite/v2/config_rpc.rs b/codex-rs/app-server/tests/suite/v2/config_rpc.rs index 99bf0d6d2cb..1b73589a6bb 100644 --- a/codex-rs/app-server/tests/suite/v2/config_rpc.rs +++ b/codex-rs/app-server/tests/suite/v2/config_rpc.rs @@ -150,7 +150,7 @@ view_image = false ); assert_eq!( origins - .get("tools.web_search.allowed_domains") + .get("tools.web_search.allowed_domains.0") .expect("origin") .name, ConfigLayerSource::User { From bba8eac6d705667da9449718f556d377ca081cab Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 6 Mar 2026 15:11:58 -0500 Subject: [PATCH 4/7] Fix config RPC tools test --- codex-rs/core/src/config/config_tests.rs | 32 +++--------------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/codex-rs/core/src/config/config_tests.rs b/codex-rs/core/src/config/config_tests.rs index 8fa95d6a29c..1e135c50dde 100644 --- a/codex-rs/core/src/config/config_tests.rs +++ b/codex-rs/core/src/config/config_tests.rs @@ -867,34 +867,6 @@ fn web_search_mode_for_turn_falls_back_when_live_is_disallowed() -> anyhow::Resu Ok(()) } -#[test] -fn profile_legacy_toggles_override_base() -> std::io::Result<()> { - let codex_home = TempDir::new()?; - let mut profiles = HashMap::new(); - profiles.insert( - "work".to_string(), - ConfigProfile { - tools_web_search: Some(false), - ..Default::default() - }, - ); - let cfg = ConfigToml { - profiles, - profile: Some("work".to_string()), - ..Default::default() - }; - - let config = Config::load_from_base_config_with_overrides( - cfg, - ConfigOverrides::default(), - codex_home.path().to_path_buf(), - )?; - - assert!(!config.features.enabled(Feature::WebSearchRequest)); - - Ok(()) -} - #[tokio::test] async fn project_profile_overrides_user_profile() -> std::io::Result<()> { let codex_home = TempDir::new()?; @@ -2712,6 +2684,7 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> { forced_login_method: None, include_apply_patch_tool: false, web_search_mode: Constrained::allow_any(WebSearchMode::Cached), + web_search_config: None, use_experimental_unified_exec_tool: !cfg!(windows), background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS, ghost_snapshot: GhostSnapshotConfig::default(), @@ -2841,6 +2814,7 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> { forced_login_method: None, include_apply_patch_tool: false, web_search_mode: Constrained::allow_any(WebSearchMode::Cached), + web_search_config: None, use_experimental_unified_exec_tool: !cfg!(windows), background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS, ghost_snapshot: GhostSnapshotConfig::default(), @@ -2968,6 +2942,7 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> { forced_login_method: None, include_apply_patch_tool: false, web_search_mode: Constrained::allow_any(WebSearchMode::Cached), + web_search_config: None, use_experimental_unified_exec_tool: !cfg!(windows), background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS, ghost_snapshot: GhostSnapshotConfig::default(), @@ -3081,6 +3056,7 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> { forced_login_method: None, include_apply_patch_tool: false, web_search_mode: Constrained::allow_any(WebSearchMode::Cached), + web_search_config: None, use_experimental_unified_exec_tool: !cfg!(windows), background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS, ghost_snapshot: GhostSnapshotConfig::default(), From 6e76635d44d46d7e526cf4dce142807989efebba Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 6 Mar 2026 15:27:41 -0500 Subject: [PATCH 5/7] fix --- codex-rs/core/src/codex.rs | 3874 ------------------------------------ 1 file changed, 3874 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d3d5db7926a..44d365233d6 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -6867,3877 +6867,3 @@ pub(crate) use tests::make_session_configuration_for_tests; #[cfg(test)] #[path = "codex_tests.rs"] mod tests; - - struct InstructionsTestCase { - slug: &'static str, - expects_apply_patch_instructions: bool, - } - - fn user_message(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn assistant_message(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn skill_message(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn developer_input_texts(items: &[ResponseItem]) -> Vec<&str> { - items - .iter() - .filter_map(|item| match item { - ResponseItem::Message { role, content, .. } if role == "developer" => { - Some(content.as_slice()) - } - _ => None, - }) - .flat_map(|content| content.iter()) - .filter_map(|item| match item { - ContentItem::InputText { text } => Some(text.as_str()), - _ => None, - }) - .collect() - } - - fn make_connector(id: &str, name: &str) -> AppInfo { - AppInfo { - id: id.to_string(), - name: name.to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: Vec::new(), - } - } - - #[test] - fn assistant_message_stream_parsers_can_be_seeded_from_output_item_added_text() { - let mut parsers = AssistantMessageStreamParsers::new(false); - let item_id = "msg-1"; - - let seeded = parsers.seed_item_text(item_id, "hello doc"); - let parsed = parsers.parse_delta(item_id, "1 world"); - let tail = parsers.finish_item(item_id); - - assert_eq!(seeded.visible_text, "hello "); - assert_eq!(seeded.citations, Vec::::new()); - assert_eq!(parsed.visible_text, " world"); - assert_eq!(parsed.citations, vec!["doc1".to_string()]); - assert_eq!(tail.visible_text, ""); - assert_eq!(tail.citations, Vec::::new()); - } - - #[test] - fn assistant_message_stream_parsers_seed_buffered_prefix_stays_out_of_finish_tail() { - let mut parsers = AssistantMessageStreamParsers::new(false); - let item_id = "msg-1"; - - let seeded = parsers.seed_item_text(item_id, "hello doc world"); - let tail = parsers.finish_item(item_id); - - assert_eq!(seeded.visible_text, "hello "); - assert_eq!(seeded.citations, Vec::::new()); - assert_eq!(parsed.visible_text, " world"); - assert_eq!(parsed.citations, vec!["doc".to_string()]); - assert_eq!(tail.visible_text, ""); - assert_eq!(tail.citations, Vec::::new()); - } - - #[test] - fn assistant_message_stream_parsers_seed_plan_parser_across_added_and_delta_boundaries() { - let mut parsers = AssistantMessageStreamParsers::new(true); - let item_id = "msg-1"; - - let seeded = parsers.seed_item_text(item_id, "Intro\n\n- step\n\nOutro"); - let tail = parsers.finish_item(item_id); - - assert_eq!(seeded.visible_text, "Intro\n"); - assert_eq!( - seeded.plan_segments, - vec![ProposedPlanSegment::Normal("Intro\n".to_string())] - ); - assert_eq!(parsed.visible_text, "Outro"); - assert_eq!( - parsed.plan_segments, - vec![ - ProposedPlanSegment::ProposedPlanStart, - ProposedPlanSegment::ProposedPlanDelta("- step\n".to_string()), - ProposedPlanSegment::ProposedPlanEnd, - ProposedPlanSegment::Normal("Outro".to_string()), - ] - ); - assert_eq!(tail.visible_text, ""); - assert!(tail.plan_segments.is_empty()); - } - - fn make_mcp_tool( - server_name: &str, - tool_name: &str, - connector_id: Option<&str>, - connector_name: Option<&str>, - ) -> ToolInfo { - ToolInfo { - server_name: server_name.to_string(), - tool_name: tool_name.to_string(), - tool: Tool { - name: tool_name.to_string().into(), - title: None, - description: Some(format!("Test tool: {tool_name}").into()), - input_schema: Arc::new(JsonObject::default()), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: connector_id.map(str::to_string), - connector_name: connector_name.map(str::to_string), - plugin_display_names: Vec::new(), - } - } - - fn function_call_rollout_item(name: &str, call_id: &str) -> RolloutItem { - RolloutItem::ResponseItem(ResponseItem::FunctionCall { - id: None, - name: name.to_string(), - arguments: "{}".to_string(), - call_id: call_id.to_string(), - }) - } - - fn function_call_output_rollout_item(call_id: &str, output: &str) -> RolloutItem { - RolloutItem::ResponseItem(ResponseItem::FunctionCallOutput { - call_id: call_id.to_string(), - output: FunctionCallOutputPayload::from_text(output.to_string()), - }) - } - - #[test] - fn validated_network_policy_amendment_host_allows_normalized_match() { - let amendment = NetworkPolicyAmendment { - host: "ExAmPlE.Com.:443".to_string(), - action: NetworkPolicyRuleAction::Allow, - }; - let context = NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Https, - }; - - let host = Session::validated_network_policy_amendment_host(&amendment, &context) - .expect("normalized hosts should match"); - - assert_eq!(host, "example.com"); - } - - #[test] - fn validated_network_policy_amendment_host_rejects_mismatch() { - let amendment = NetworkPolicyAmendment { - host: "evil.example.com".to_string(), - action: NetworkPolicyRuleAction::Deny, - }; - let context = NetworkApprovalContext { - host: "api.example.com".to_string(), - protocol: NetworkApprovalProtocol::Https, - }; - - let err = Session::validated_network_policy_amendment_host(&amendment, &context) - .expect_err("mismatched hosts should be rejected"); - - let message = err.to_string(); - assert!(message.contains("does not match approved host")); - } - - #[tokio::test] - async fn get_base_instructions_no_user_content() { - let prompt_with_apply_patch_instructions = - include_str!("../prompt_with_apply_patch_instructions.md"); - let models_response: ModelsResponse = - serde_json::from_str(include_str!("../models.json")).expect("valid models.json"); - let model_info_for_slug = |slug: &str, config: &Config| { - let model = models_response - .models - .iter() - .find(|candidate| candidate.slug == slug) - .cloned() - .unwrap_or_else(|| panic!("model slug {slug} is missing from models.json")); - model_info::with_config_overrides(model, config) - }; - let test_cases = vec![ - InstructionsTestCase { - slug: "gpt-5", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5.1", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5.1-codex", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5.1-codex-max", - expects_apply_patch_instructions: false, - }, - ]; - - let (session, _turn_context) = make_session_and_context().await; - let config = test_config(); - - for test_case in test_cases { - let model_info = model_info_for_slug(test_case.slug, &config); - if test_case.expects_apply_patch_instructions { - assert_eq!( - model_info.base_instructions.as_str(), - prompt_with_apply_patch_instructions - ); - } - - { - let mut state = session.state.lock().await; - state.session_configuration.base_instructions = - model_info.base_instructions.clone(); - } - - let base_instructions = session.get_base_instructions().await; - assert_eq!(base_instructions.text, model_info.base_instructions); - } - } - - #[tokio::test] - async fn reload_user_config_layer_updates_effective_apps_config() { - let (session, _turn_context) = make_session_and_context().await; - let codex_home = session.codex_home().await; - std::fs::create_dir_all(&codex_home).expect("create codex home"); - let config_toml_path = codex_home.join(CONFIG_TOML_FILE); - std::fs::write( - &config_toml_path, - "[apps.calendar]\nenabled = false\ndestructive_enabled = false\n", - ) - .expect("write user config"); - - session.reload_user_config_layer().await; - - let config = session.get_config().await; - let apps_toml = config - .config_layer_stack - .effective_config() - .as_table() - .and_then(|table| table.get("apps")) - .cloned() - .expect("apps table"); - let apps = crate::config::types::AppsConfigToml::deserialize(apps_toml) - .expect("deserialize apps config"); - let app = apps - .apps - .get("calendar") - .expect("calendar app config exists"); - - assert!(!app.enabled); - assert_eq!(app.destructive_enabled, Some(false)); - } - - #[test] - fn filter_connectors_for_input_skips_duplicate_slug_mentions() { - let connectors = vec![ - make_connector("one", "Foo Bar"), - make_connector("two", "Foo-Bar"), - ]; - let input = vec![user_message("use $foo-bar")]; - let explicitly_enabled_connectors = HashSet::new(); - let skill_name_counts_lower = HashMap::new(); - - let selected = filter_connectors_for_input( - &connectors, - &input, - &explicitly_enabled_connectors, - &skill_name_counts_lower, - ); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn filter_connectors_for_input_skips_when_skill_name_conflicts() { - let connectors = vec![make_connector("one", "Todoist")]; - let input = vec![user_message("use $todoist")]; - let explicitly_enabled_connectors = HashSet::new(); - let skill_name_counts_lower = HashMap::from([("todoist".to_string(), 1)]); - - let selected = filter_connectors_for_input( - &connectors, - &input, - &explicitly_enabled_connectors, - &skill_name_counts_lower, - ); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn filter_connectors_for_input_skips_disabled_connectors() { - let mut connector = make_connector("calendar", "Calendar"); - connector.is_enabled = false; - let input = vec![user_message("use $calendar")]; - let explicitly_enabled_connectors = HashSet::new(); - let selected = filter_connectors_for_input( - &[connector], - &input, - &explicitly_enabled_connectors, - &HashMap::new(), - ); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_app_ids_from_skill_items_includes_linked_mentions() { - let connectors = vec![make_connector("calendar", "Calendar")]; - let skill_items = vec![skill_message( - "\ndemo\n/tmp/skills/demo/SKILL.md\nuse [$calendar](app://calendar)\n", - )]; - - let connector_ids = - collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); - - assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); - } - - #[test] - fn collect_explicit_app_ids_from_skill_items_resolves_unambiguous_plain_mentions() { - let connectors = vec![make_connector("calendar", "Calendar")]; - let skill_items = vec![skill_message( - "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", - )]; - - let connector_ids = - collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); - - assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); - } - - #[test] - fn collect_explicit_app_ids_from_skill_items_skips_plain_mentions_with_skill_conflicts() { - let connectors = vec![make_connector("calendar", "Calendar")]; - let skill_items = vec![skill_message( - "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", - )]; - let skill_name_counts_lower = HashMap::from([("calendar".to_string(), 1)]); - - let connector_ids = collect_explicit_app_ids_from_skill_items( - &skill_items, - &connectors, - &skill_name_counts_lower, - ); - - assert_eq!(connector_ids, HashSet::::new()); - } - - #[test] - fn non_app_mcp_tools_remain_visible_without_search_selection() { - let mcp_tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - make_mcp_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - Some("calendar"), - Some("Calendar"), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - make_mcp_tool("rmcp", "echo", None, None), - ), - ]); - - let mut selected_mcp_tools = mcp_tools - .iter() - .filter(|(_, tool)| tool.server_name != CODEX_APPS_MCP_SERVER_NAME) - .map(|(name, tool)| (name.clone(), tool.clone())) - .collect::>(); - - let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); - let explicitly_enabled_connectors = HashSet::new(); - let connectors = filter_connectors_for_input( - &connectors, - &[user_message("run echo")], - &explicitly_enabled_connectors, - &HashMap::new(), - ); - let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); - selected_mcp_tools.extend(apps_mcp_tools); - - let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); - tool_names.sort(); - assert_eq!(tool_names, vec!["mcp__rmcp__echo".to_string()]); - } - - #[test] - fn search_tool_selection_keeps_codex_apps_tools_without_mentions() { - let selected_tool_names = vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__rmcp__echo".to_string(), - ]; - let mcp_tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - make_mcp_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - Some("calendar"), - Some("Calendar"), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - make_mcp_tool("rmcp", "echo", None, None), - ), - ]); - - let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); - let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); - let explicitly_enabled_connectors = HashSet::new(); - let connectors = filter_connectors_for_input( - &connectors, - &[user_message("run the selected tools")], - &explicitly_enabled_connectors, - &HashMap::new(), - ); - let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); - selected_mcp_tools.extend(apps_mcp_tools); - - let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); - tool_names.sort(); - assert_eq!( - tool_names, - vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__rmcp__echo".to_string(), - ] - ); - } - - #[test] - fn apps_mentions_add_codex_apps_tools_to_search_selected_set() { - let selected_tool_names = vec!["mcp__rmcp__echo".to_string()]; - let mcp_tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - make_mcp_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - Some("calendar"), - Some("Calendar"), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - make_mcp_tool("rmcp", "echo", None, None), - ), - ]); - - let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); - let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); - let explicitly_enabled_connectors = HashSet::new(); - let connectors = filter_connectors_for_input( - &connectors, - &[user_message("use $calendar and then echo the response")], - &explicitly_enabled_connectors, - &HashMap::new(), - ); - let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); - selected_mcp_tools.extend(apps_mcp_tools); - - let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); - tool_names.sort(); - assert_eq!( - tool_names, - vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__rmcp__echo".to_string(), - ] - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_reads_search_tool_output() { - let rollout_items = vec![ - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), - function_call_output_rollout_item( - "search-1", - &json!({ - "active_selected_tools": [ - "mcp__codex_apps__calendar_create_event", - "mcp__codex_apps__calendar_list_events", - ], - }) - .to_string(), - ), - ]; - - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!( - selected, - Some(vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__codex_apps__calendar_list_events".to_string(), - ]) - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_latest_valid_payload_wins() { - let rollout_items = vec![ - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), - function_call_output_rollout_item( - "search-1", - &json!({ - "active_selected_tools": ["mcp__codex_apps__calendar_create_event"], - }) - .to_string(), - ), - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-2"), - function_call_output_rollout_item( - "search-2", - &json!({ - "active_selected_tools": ["mcp__codex_apps__calendar_delete_event"], - }) - .to_string(), - ), - ]; - - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!( - selected, - Some(vec!["mcp__codex_apps__calendar_delete_event".to_string(),]) - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_ignores_non_search_and_malformed_payloads() { - let rollout_items = vec![ - function_call_rollout_item("shell", "shell-1"), - function_call_output_rollout_item( - "shell-1", - &json!({ - "active_selected_tools": ["mcp__codex_apps__should_be_ignored"], - }) - .to_string(), - ), - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), - function_call_output_rollout_item("search-1", "{not-json"), - function_call_output_rollout_item( - "unknown-search-call", - &json!({ - "active_selected_tools": ["mcp__codex_apps__also_ignored"], - }) - .to_string(), - ), - function_call_output_rollout_item( - "search-1", - &json!({ - "active_selected_tools": ["mcp__codex_apps__calendar_list_events"], - }) - .to_string(), - ), - ]; - - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!( - selected, - Some(vec!["mcp__codex_apps__calendar_list_events".to_string(),]) - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_returns_none_without_valid_search_output() { - let rollout_items = vec![function_call_rollout_item( - SEARCH_TOOL_BM25_TOOL_NAME, - "search-1", - )]; - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!(selected, None); - } - - #[tokio::test] - async fn reconstruct_history_matches_live_compactions() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; - - let reconstruction_turn = session.new_default_turn().await; - let reconstructed = session - .reconstruct_history_from_rollout(reconstruction_turn.as_ref(), &rollout_items) - .await; - - assert_eq!(expected, reconstructed.history); - } - - #[tokio::test] - async fn reconstruct_history_uses_replacement_history_verbatim() { - let (session, turn_context) = make_session_and_context().await; - let summary_item = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }; - let replacement_history = vec![ - summary_item.clone(), - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "stale developer instructions".to_string(), - }], - end_turn: None, - phase: None, - }, - ]; - let rollout_items = vec![RolloutItem::Compacted(CompactedItem { - message: String::new(), - replacement_history: Some(replacement_history.clone()), - })]; - - let reconstructed = session - .reconstruct_history_from_rollout(&turn_context, &rollout_items) - .await; - - assert_eq!(reconstructed.history, replacement_history); - } - - #[tokio::test] - async fn record_initial_history_reconstructs_resumed_transcript() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; - - session - .record_initial_history(InitialHistory::Resumed(ResumedHistory { - conversation_id: ThreadId::default(), - history: rollout_items, - rollout_path: PathBuf::from("/tmp/resume.jsonl"), - })) - .await; - - let history = session.state.lock().await.clone_history(); - assert_eq!(expected, history.raw_items()); - } - - #[tokio::test] - async fn resumed_history_injects_initial_context_on_first_context_update_only() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; - - session - .record_initial_history(InitialHistory::Resumed(ResumedHistory { - conversation_id: ThreadId::default(), - history: rollout_items, - rollout_path: PathBuf::from("/tmp/resume.jsonl"), - })) - .await; - - let history_before_seed = session.state.lock().await.clone_history(); - assert_eq!(expected, history_before_seed.raw_items()); - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - expected.extend(session.build_initial_context(&turn_context).await); - let history_after_seed = session.clone_history().await; - assert_eq!(expected, history_after_seed.raw_items()); - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - let history_after_second_seed = session.clone_history().await; - assert_eq!( - history_after_seed.raw_items(), - history_after_second_seed.raw_items() - ); - } - - #[tokio::test] - async fn record_initial_history_seeds_token_info_from_rollout() { - let (session, turn_context) = make_session_and_context().await; - let (mut rollout_items, _expected) = sample_rollout(&session, &turn_context).await; - - let info1 = TokenUsageInfo { - total_token_usage: TokenUsage { - input_tokens: 10, - cached_input_tokens: 0, - output_tokens: 20, - reasoning_output_tokens: 0, - total_tokens: 30, - }, - last_token_usage: TokenUsage { - input_tokens: 3, - cached_input_tokens: 0, - output_tokens: 4, - reasoning_output_tokens: 0, - total_tokens: 7, - }, - model_context_window: Some(1_000), - }; - let info2 = TokenUsageInfo { - total_token_usage: TokenUsage { - input_tokens: 100, - cached_input_tokens: 50, - output_tokens: 200, - reasoning_output_tokens: 25, - total_tokens: 375, - }, - last_token_usage: TokenUsage { - input_tokens: 10, - cached_input_tokens: 0, - output_tokens: 20, - reasoning_output_tokens: 5, - total_tokens: 35, - }, - model_context_window: Some(2_000), - }; - - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: Some(info1), - rate_limits: None, - }, - ))); - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: None, - rate_limits: None, - }, - ))); - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: Some(info2.clone()), - rate_limits: None, - }, - ))); - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: None, - rate_limits: None, - }, - ))); - - session - .record_initial_history(InitialHistory::Resumed(ResumedHistory { - conversation_id: ThreadId::default(), - history: rollout_items, - rollout_path: PathBuf::from("/tmp/resume.jsonl"), - })) - .await; - - let actual = session.state.lock().await.token_info(); - assert_eq!(actual, Some(info2)); - } - - #[tokio::test] - async fn recompute_token_usage_uses_session_base_instructions() { - let (session, turn_context) = make_session_and_context().await; - - let override_instructions = "SESSION_OVERRIDE_INSTRUCTIONS_ONLY".repeat(120); - { - let mut state = session.state.lock().await; - state.session_configuration.base_instructions = override_instructions.clone(); - } - - let item = user_message("hello"); - session - .record_into_history(std::slice::from_ref(&item), &turn_context) - .await; - - let history = session.clone_history().await; - let session_base_instructions = BaseInstructions { - text: override_instructions, - }; - let expected_tokens = history - .estimate_token_count_with_base_instructions(&session_base_instructions) - .expect("estimate with session base instructions"); - let model_estimated_tokens = history - .estimate_token_count(&turn_context) - .expect("estimate with model instructions"); - assert_ne!(expected_tokens, model_estimated_tokens); - - session.recompute_token_usage(&turn_context).await; - - let actual_tokens = session - .state - .lock() - .await - .token_info() - .expect("token info") - .last_token_usage - .total_tokens; - assert_eq!(actual_tokens, expected_tokens.max(0)); - } - - #[tokio::test] - async fn recompute_token_usage_updates_model_context_window() { - let (session, mut turn_context) = make_session_and_context().await; - - { - let mut state = session.state.lock().await; - state.set_token_info(Some(TokenUsageInfo { - total_token_usage: TokenUsage::default(), - last_token_usage: TokenUsage::default(), - model_context_window: Some(258_400), - })); - } - - turn_context.model_info.context_window = Some(128_000); - turn_context.model_info.effective_context_window_percent = 100; - - session.recompute_token_usage(&turn_context).await; - - let actual = session.state.lock().await.token_info().expect("token info"); - assert_eq!(actual.model_context_window, Some(128_000)); - } - - #[tokio::test] - async fn record_initial_history_reconstructs_forked_transcript() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; - - session - .record_initial_history(InitialHistory::Forked(rollout_items)) - .await; - - let reconstruction_turn = session.new_default_turn().await; - expected.extend( - session - .build_initial_context(reconstruction_turn.as_ref()) - .await, - ); - let history = session.state.lock().await.clone_history(); - assert_eq!(expected, history.raw_items()); - } - - #[tokio::test] - async fn record_initial_history_forked_hydrates_previous_turn_settings() { - let (session, turn_context) = make_session_and_context().await; - let previous_model = "forked-rollout-model"; - let previous_context_item = TurnContextItem { - turn_id: Some(turn_context.sub_id.clone()), - trace_id: turn_context.trace_id.clone(), - cwd: turn_context.cwd.clone(), - current_date: turn_context.current_date.clone(), - timezone: turn_context.timezone.clone(), - approval_policy: turn_context.approval_policy.value(), - sandbox_policy: turn_context.sandbox_policy.get().clone(), - network: None, - model: previous_model.to_string(), - personality: turn_context.personality, - collaboration_mode: Some(turn_context.collaboration_mode.clone()), - realtime_active: Some(turn_context.realtime_active), - effort: turn_context.reasoning_effort, - summary: turn_context.reasoning_summary, - user_instructions: None, - developer_instructions: None, - final_output_json_schema: None, - truncation_policy: Some(turn_context.truncation_policy.into()), - }; - let turn_id = previous_context_item - .turn_id - .clone() - .expect("turn context should have turn_id"); - let rollout_items = vec![ - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: turn_id.clone(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage( - codex_protocol::protocol::UserMessageEvent { - message: "forked seed".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - )), - RolloutItem::TurnContext(previous_context_item), - RolloutItem::EventMsg(EventMsg::TurnComplete( - codex_protocol::protocol::TurnCompleteEvent { - turn_id, - last_agent_message: None, - }, - )), - ]; - - session - .record_initial_history(InitialHistory::Forked(rollout_items)) - .await; - - assert_eq!( - session.previous_turn_settings().await, - Some(PreviousTurnSettings { - model: previous_model.to_string(), - realtime_active: Some(turn_context.realtime_active), - }) - ); - } - - #[tokio::test] - async fn thread_rollback_drops_last_turn_from_history() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let rollout_path = attach_rollout_recorder(&sess).await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - let turn_1 = vec![ - user_message("turn 1 user"), - assistant_message("turn 1 assistant"), - ]; - let turn_2 = vec![ - user_message("turn 2 user"), - assistant_message("turn 2 assistant"), - ]; - let mut full_history = Vec::new(); - full_history.extend(initial_context.clone()); - full_history.extend(turn_1.clone()); - full_history.extend(turn_2); - sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) - .await; - let rollout_items: Vec = full_history - .into_iter() - .map(RolloutItem::ResponseItem) - .collect(); - sess.persist_rollout_items(&rollout_items).await; - sess.set_previous_turn_settings(Some(PreviousTurnSettings { - model: "stale-model".to_string(), - realtime_active: Some(tc.realtime_active), - })) - .await; - { - let mut state = sess.state.lock().await; - state.set_reference_context_item(Some(tc.to_turn_context_item())); - } - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - - let rollback_event = wait_for_thread_rolled_back(&rx).await; - assert_eq!(rollback_event.num_turns, 1); - - let mut expected = Vec::new(); - expected.extend(initial_context); - expected.extend(turn_1); - - let history = sess.clone_history().await; - assert_eq!(expected, history.raw_items()); - assert_eq!(sess.previous_turn_settings().await, None); - assert!(sess.reference_context_item().await.is_none()); - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - assert!(resumed.history.iter().any(|item| { - matches!( - item, - RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) - if rollback.num_turns == 1 - ) - })); - } - - #[tokio::test] - async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - attach_rollout_recorder(&sess).await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - let turn_1 = vec![user_message("turn 1 user")]; - let mut full_history = Vec::new(); - full_history.extend(initial_context.clone()); - full_history.extend(turn_1); - sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) - .await; - let rollout_items: Vec = full_history - .into_iter() - .map(RolloutItem::ResponseItem) - .collect(); - sess.persist_rollout_items(&rollout_items).await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 99).await; - - let rollback_event = wait_for_thread_rolled_back(&rx).await; - assert_eq!(rollback_event.num_turns, 99); - - let history = sess.clone_history().await; - assert_eq!(initial_context, history.raw_items()); - } - - #[tokio::test] - async fn thread_rollback_fails_without_persisted_rollout_path() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - sess.record_into_history(&initial_context, tc.as_ref()) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - - let error_event = wait_for_thread_rollback_failed(&rx).await; - assert_eq!( - error_event.message, - "thread rollback requires a persisted rollout path" - ); - assert_eq!( - error_event.codex_error_info, - Some(CodexErrorInfo::ThreadRollbackFailed) - ); - assert_eq!(sess.clone_history().await.raw_items(), initial_context); - } - - #[tokio::test] - async fn thread_rollback_recomputes_previous_turn_settings_and_reference_context_from_replay() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - attach_rollout_recorder(&sess).await; - - let first_context_item = tc.to_turn_context_item(); - let first_turn_id = first_context_item - .turn_id - .clone() - .expect("turn context should have turn_id"); - let mut rolled_back_context_item = first_context_item.clone(); - rolled_back_context_item.turn_id = Some("rolled-back-turn".to_string()); - rolled_back_context_item.model = "rolled-back-model".to_string(); - let rolled_back_turn_id = rolled_back_context_item - .turn_id - .clone() - .expect("turn context should have turn_id"); - let turn_one_user = user_message("turn 1 user"); - let turn_one_assistant = assistant_message("turn 1 assistant"); - let turn_two_user = user_message("turn 2 user"); - let turn_two_assistant = assistant_message("turn 2 assistant"); - - sess.persist_rollout_items(&[ - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: first_turn_id.clone(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage( - codex_protocol::protocol::UserMessageEvent { - message: "turn 1 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - )), - RolloutItem::TurnContext(first_context_item.clone()), - RolloutItem::ResponseItem(turn_one_user.clone()), - RolloutItem::ResponseItem(turn_one_assistant.clone()), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: first_turn_id, - last_agent_message: None, - })), - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: rolled_back_turn_id.clone(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage( - codex_protocol::protocol::UserMessageEvent { - message: "turn 2 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - )), - RolloutItem::TurnContext(rolled_back_context_item), - RolloutItem::ResponseItem(turn_two_user), - RolloutItem::ResponseItem(turn_two_assistant), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: rolled_back_turn_id, - last_agent_message: None, - })), - ]) - .await; - sess.replace_history( - vec![assistant_message("stale history")], - Some(first_context_item.clone()), - ) - .await; - sess.set_previous_turn_settings(Some(PreviousTurnSettings { - model: "stale-model".to_string(), - realtime_active: None, - })) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - let rollback_event = wait_for_thread_rolled_back(&rx).await; - assert_eq!(rollback_event.num_turns, 1); - - assert_eq!( - sess.clone_history().await.raw_items(), - vec![turn_one_user, turn_one_assistant] - ); - assert_eq!( - sess.previous_turn_settings().await, - Some(PreviousTurnSettings { - model: tc.model_info.slug.clone(), - realtime_active: Some(tc.realtime_active), - }) - ); - assert_eq!( - serde_json::to_value(sess.reference_context_item().await) - .expect("serialize replay reference context item"), - serde_json::to_value(Some(first_context_item)) - .expect("serialize expected reference context item") - ); - } - - #[tokio::test] - async fn thread_rollback_persists_marker_and_replays_cumulatively() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let rollout_path = attach_rollout_recorder(&sess).await; - let turn_context_item = tc.to_turn_context_item(); - - sess.persist_rollout_items(&[ - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: "turn-1".to_string(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "turn 1 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - RolloutItem::TurnContext(turn_context_item.clone()), - RolloutItem::ResponseItem(user_message("turn 1 user")), - RolloutItem::ResponseItem(assistant_message("turn 1 assistant")), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-1".to_string(), - last_agent_message: None, - })), - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: "turn-2".to_string(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "turn 2 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - RolloutItem::TurnContext(turn_context_item.clone()), - RolloutItem::ResponseItem(user_message("turn 2 user")), - RolloutItem::ResponseItem(assistant_message("turn 2 assistant")), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-2".to_string(), - last_agent_message: None, - })), - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: "turn-3".to_string(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "turn 3 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - RolloutItem::TurnContext(turn_context_item), - RolloutItem::ResponseItem(user_message("turn 3 user")), - RolloutItem::ResponseItem(assistant_message("turn 3 assistant")), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-3".to_string(), - last_agent_message: None, - })), - ]) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - let first_rollback = wait_for_thread_rolled_back(&rx).await; - assert_eq!(first_rollback.num_turns, 1); - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - let second_rollback = wait_for_thread_rolled_back(&rx).await; - assert_eq!(second_rollback.num_turns, 1); - - assert_eq!( - sess.clone_history().await.raw_items(), - vec![ - user_message("turn 1 user"), - assistant_message("turn 1 assistant") - ] - ); - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - let rollback_markers = resumed - .history - .iter() - .filter(|item| matches!(item, RolloutItem::EventMsg(EventMsg::ThreadRolledBack(_)))) - .count(); - assert_eq!(rollback_markers, 2); - } - - #[tokio::test] - async fn thread_rollback_fails_when_turn_in_progress() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - sess.record_into_history(&initial_context, tc.as_ref()) - .await; - - *sess.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - - let error_event = wait_for_thread_rollback_failed(&rx).await; - assert_eq!( - error_event.codex_error_info, - Some(CodexErrorInfo::ThreadRollbackFailed) - ); - - let history = sess.clone_history().await; - assert_eq!(initial_context, history.raw_items()); - } - - #[tokio::test] - async fn thread_rollback_fails_when_num_turns_is_zero() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - sess.record_into_history(&initial_context, tc.as_ref()) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 0).await; - - let error_event = wait_for_thread_rollback_failed(&rx).await; - assert_eq!(error_event.message, "num_turns must be >= 1"); - assert_eq!( - error_event.codex_error_info, - Some(CodexErrorInfo::ThreadRollbackFailed) - ); - - let history = sess.clone_history().await; - assert_eq!(initial_context, history.raw_items()); - } - - #[tokio::test] - async fn set_rate_limits_retains_previous_credits() { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - - let mut state = SessionState::new(session_configuration); - let initial = RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 10.0, - window_minutes: Some(15), - resets_at: Some(1_700), - }), - secondary: None, - credits: Some(CreditsSnapshot { - has_credits: true, - unlimited: false, - balance: Some("10.00".to_string()), - }), - plan_type: Some(codex_protocol::account::PlanType::Plus), - }; - state.set_rate_limits(initial.clone()); - - let update = RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: Some("codex_other".to_string()), - primary: Some(RateLimitWindow { - used_percent: 40.0, - window_minutes: Some(30), - resets_at: Some(1_800), - }), - secondary: Some(RateLimitWindow { - used_percent: 5.0, - window_minutes: Some(60), - resets_at: Some(1_900), - }), - credits: None, - plan_type: None, - }; - state.set_rate_limits(update.clone()); - - assert_eq!( - state.latest_rate_limits, - Some(RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: Some("codex_other".to_string()), - primary: update.primary.clone(), - secondary: update.secondary, - credits: initial.credits, - plan_type: initial.plan_type, - }) - ); - } - - #[tokio::test] - async fn set_rate_limits_updates_plan_type_when_present() { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - - let mut state = SessionState::new(session_configuration); - let initial = RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 15.0, - window_minutes: Some(20), - resets_at: Some(1_600), - }), - secondary: Some(RateLimitWindow { - used_percent: 5.0, - window_minutes: Some(45), - resets_at: Some(1_650), - }), - credits: Some(CreditsSnapshot { - has_credits: true, - unlimited: false, - balance: Some("15.00".to_string()), - }), - plan_type: Some(codex_protocol::account::PlanType::Plus), - }; - state.set_rate_limits(initial.clone()); - - let update = RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 35.0, - window_minutes: Some(25), - resets_at: Some(1_700), - }), - secondary: None, - credits: None, - plan_type: Some(codex_protocol::account::PlanType::Pro), - }; - state.set_rate_limits(update.clone()); - - assert_eq!( - state.latest_rate_limits, - Some(RateLimitSnapshot { - limit_id: Some("codex".to_string()), - limit_name: None, - primary: update.primary, - secondary: update.secondary, - credits: initial.credits, - plan_type: update.plan_type, - }) - ); - } - - #[test] - fn prefers_structured_content_when_present() { - let ctr = McpCallToolResult { - // Content present but should be ignored because structured_content is set. - content: vec![text_block("ignored")], - is_error: None, - structured_content: Some(json!({ - "ok": true, - "value": 42 - })), - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&json!({ - "ok": true, - "value": 42 - })) - .unwrap(), - ), - success: Some(true), - }; - - assert_eq!(expected, got); - } - - #[tokio::test] - async fn includes_timed_out_message() { - let exec = ExecToolCallOutput { - exit_code: 0, - stdout: StreamOutput::new(String::new()), - stderr: StreamOutput::new(String::new()), - aggregated_output: StreamOutput::new("Command output".to_string()), - duration: StdDuration::from_secs(1), - timed_out: true, - }; - let (_, turn_context) = make_session_and_context().await; - - let out = format_exec_output_str(&exec, turn_context.truncation_policy); - - assert_eq!( - out, - "command timed out after 1000 milliseconds\nCommand output" - ); - } - - #[tokio::test] - async fn turn_context_with_model_updates_model_fields() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.reasoning_effort = Some(ReasoningEffortConfig::Minimal); - let updated = turn_context - .with_model("gpt-5.1".to_string(), &session.services.models_manager) - .await; - let expected_model_info = session - .services - .models_manager - .get_model_info("gpt-5.1", updated.config.as_ref()) - .await; - - assert_eq!(updated.config.model.as_deref(), Some("gpt-5.1")); - assert_eq!(updated.collaboration_mode.model(), "gpt-5.1"); - assert_eq!(updated.model_info, expected_model_info); - assert_eq!( - updated.reasoning_effort, - Some(ReasoningEffortConfig::Medium) - ); - assert_eq!( - updated.collaboration_mode.reasoning_effort(), - Some(ReasoningEffortConfig::Medium) - ); - assert_eq!( - updated.config.model_reasoning_effort, - Some(ReasoningEffortConfig::Medium) - ); - assert_eq!( - updated.truncation_policy, - expected_model_info.truncation_policy.into() - ); - assert!(!Arc::ptr_eq( - &updated.tool_call_gate, - &turn_context.tool_call_gate - )); - } - - #[test] - fn falls_back_to_content_when_structured_is_null() { - let ctr = McpCallToolResult { - content: vec![text_block("hello"), text_block("world")], - is_error: None, - structured_content: Some(serde_json::Value::Null), - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&vec![text_block("hello"), text_block("world")]).unwrap(), - ), - success: Some(true), - }; - - assert_eq!(expected, got); - } - - #[test] - fn success_flag_reflects_is_error_true() { - let ctr = McpCallToolResult { - content: vec![text_block("unused")], - is_error: Some(true), - structured_content: Some(json!({ "message": "bad" })), - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&json!({ "message": "bad" })).unwrap(), - ), - success: Some(false), - }; - - assert_eq!(expected, got); - } - - #[test] - fn success_flag_true_with_no_error_and_content_used() { - let ctr = McpCallToolResult { - content: vec![text_block("alpha")], - is_error: Some(false), - structured_content: None, - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&vec![text_block("alpha")]).unwrap(), - ), - success: Some(true), - }; - - assert_eq!(expected, got); - } - - async fn wait_for_thread_rolled_back( - rx: &async_channel::Receiver, - ) -> crate::protocol::ThreadRolledBackEvent { - let deadline = StdDuration::from_secs(2); - let start = std::time::Instant::now(); - loop { - let remaining = deadline.saturating_sub(start.elapsed()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::ThreadRolledBack(payload) => return payload, - _ => continue, - } - } - } - - async fn wait_for_thread_rollback_failed(rx: &async_channel::Receiver) -> ErrorEvent { - let deadline = StdDuration::from_secs(2); - let start = std::time::Instant::now(); - loop { - let remaining = deadline.saturating_sub(start.elapsed()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::Error(payload) - if payload.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) => - { - return payload; - } - _ => continue, - } - } - } - - async fn attach_rollout_recorder(session: &Arc) -> PathBuf { - let config = session.get_config().await; - let recorder = RolloutRecorder::new( - config.as_ref(), - RolloutRecorderParams::new( - ThreadId::default(), - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await - .expect("create rollout recorder"); - let rollout_path = recorder.rollout_path().to_path_buf(); - { - let mut rollout = session.services.rollout.lock().await; - *rollout = Some(recorder); - } - session.ensure_rollout_materialized().await; - session.flush_rollout().await; - rollout_path - } - - fn text_block(s: &str) -> serde_json::Value { - json!({ - "type": "text", - "text": s, - }) - } - - fn init_test_tracing() { - static INIT: Once = Once::new(); - INIT.call_once(|| { - let provider = SdkTracerProvider::builder().build(); - let tracer = provider.tracer("codex-core-tests"); - let subscriber = tracing_subscriber::registry() - .with(tracing_opentelemetry::layer().with_tracer(tracer)); - tracing::subscriber::set_global_default(subscriber) - .expect("global tracing subscriber should only be installed once"); - }); - } - - async fn build_test_config(codex_home: &Path) -> Config { - ConfigBuilder::default() - .codex_home(codex_home.to_path_buf()) - .build() - .await - .expect("load default test config") - } - - fn otel_manager( - conversation_id: ThreadId, - config: &Config, - model_info: &ModelInfo, - session_source: SessionSource, - ) -> OtelManager { - OtelManager::new( - conversation_id, - ModelsManager::get_model_offline_for_tests(config.model.as_deref()).as_str(), - model_info.slug.as_str(), - None, - Some("test@test.com".to_string()), - Some(TelemetryAuthMode::Chatgpt), - "test_originator".to_string(), - false, - "test".to_string(), - session_source, - ) - } - - pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - - SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - } - } - - #[tokio::test] - async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let mut config = build_test_config(codex_home.path()).await; - config - .features - .enable(Feature::ShellZshFork) - .expect("test config should allow shell_zsh_fork"); - config.zsh_path = None; - let config = Arc::new(config); - - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( - config.codex_home.clone(), - auth_manager.clone(), - None, - CollaborationModesConfig::default(), - )); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort: config.model_reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - - let (tx_event, _rx_event) = async_channel::unbounded(); - let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); - let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let skills_manager = Arc::new(SkillsManager::new( - config.codex_home.clone(), - Arc::clone(&plugins_manager), - )); - let result = Session::new( - session_configuration, - Arc::clone(&config), - auth_manager, - models_manager, - ExecPolicyManager::default(), - tx_event, - agent_status_tx, - InitialHistory::New, - SessionSource::Exec, - skills_manager, - plugins_manager, - mcp_manager, - Arc::new(FileWatcher::noop()), - AgentControl::default(), - ) - .await; - - let err = match result { - Ok(_) => panic!("expected startup to fail"), - Err(err) => err, - }; - let msg = format!("{err:#}"); - assert!(msg.contains("zsh fork feature enabled, but `zsh_path` is not configured")); - } - - // todo: use online model info - pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { - let (tx_event, _rx_event) = async_channel::unbounded(); - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let conversation_id = ThreadId::default(); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( - config.codex_home.clone(), - auth_manager.clone(), - None, - CollaborationModesConfig::default(), - )); - let agent_control = AgentControl::default(); - let exec_policy = ExecPolicyManager::default(); - let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - let per_turn_config = Session::build_per_turn_config(&session_configuration); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - session_configuration.collaboration_mode.model(), - &per_turn_config, - ); - let otel_manager = otel_manager( - conversation_id, - config.as_ref(), - &model_info, - session_configuration.session_source.clone(), - ); - - let state = SessionState::new(session_configuration.clone()); - let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); - let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let skills_manager = Arc::new(SkillsManager::new( - config.codex_home.clone(), - Arc::clone(&plugins_manager), - )); - let network_approval = Arc::new(NetworkApprovalService::default()); - - let file_watcher = Arc::new(FileWatcher::noop()); - let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new( - McpConnectionManager::new_mcp_connection_manager_for_tests( - &config.permissions.approval_policy, - ), - )), - mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), - unified_exec_manager: UnifiedExecProcessManager::new( - config.background_terminal_max_timeout, - ), - shell_zsh_path: None, - main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), - analytics_events_client: AnalyticsEventsClient::new( - Arc::clone(&config), - Arc::clone(&auth_manager), - ), - hooks: Hooks::new(HooksConfig { - legacy_notify_argv: config.notify.clone(), - }), - rollout: Mutex::new(None), - user_shell: Arc::new(default_user_shell()), - shell_snapshot_tx: watch::channel(None).0, - show_raw_agent_reasoning: config.show_raw_agent_reasoning, - exec_policy, - auth_manager: auth_manager.clone(), - otel_manager: otel_manager.clone(), - models_manager: Arc::clone(&models_manager), - tool_approvals: Mutex::new(ApprovalStore::default()), - execve_session_approvals: RwLock::new(HashMap::new()), - skills_manager, - plugins_manager, - mcp_manager, - file_watcher, - agent_control, - network_proxy: None, - network_approval: Arc::clone(&network_approval), - state_db: None, - model_client: ModelClient::new( - Some(auth_manager.clone()), - conversation_id, - session_configuration.provider.clone(), - session_configuration.session_source.clone(), - config.model_verbosity, - ws_version_from_features(config.as_ref()), - config.features.enabled(Feature::EnableRequestCompression), - config.features.enabled(Feature::RuntimeMetrics), - Session::build_model_client_beta_features_header(config.as_ref()), - ), - }; - let js_repl = Arc::new(JsReplHandle::with_node_path( - config.js_repl_node_path.clone(), - config.js_repl_node_module_dirs.clone(), - )); - - let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); - let turn_context = Session::make_turn_context( - Some(Arc::clone(&auth_manager)), - &otel_manager, - session_configuration.provider.clone(), - &session_configuration, - per_turn_config, - model_info, - None, - "turn_id".to_string(), - Arc::clone(&js_repl), - skills_outcome, - ); - - let session = Session { - conversation_id, - tx_event, - agent_status: agent_status_tx, - state: Mutex::new(state), - features: config.features.clone(), - pending_mcp_server_refresh_config: Mutex::new(None), - conversation: Arc::new(RealtimeConversationManager::new()), - active_turn: Mutex::new(None), - services, - js_repl, - next_internal_sub_id: AtomicU64::new(0), - }; - - (session, turn_context) - } - - #[tokio::test] - async fn submit_with_id_captures_current_span_trace_context() { - let (session, _turn_context) = make_session_and_context().await; - let (tx_sub, rx_sub) = async_channel::bounded(1); - let (_tx_event, rx_event) = async_channel::unbounded(); - let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); - let codex = Codex { - tx_sub, - rx_event, - agent_status, - session: Arc::new(session), - }; - - init_test_tracing(); - - let request_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), - tracestate: Some("vendor=value".into()), - }; - let request_span = info_span!("app_server.request"); - assert!(set_parent_from_w3c_trace_context( - &request_span, - &request_parent - )); - - let expected_trace = async { - let expected_trace = - current_span_w3c_trace_context().expect("current span should have trace context"); - codex - .submit_with_id(Submission { - id: "sub-1".into(), - op: Op::Interrupt, - trace: None, - }) - .await - .expect("submit should succeed"); - expected_trace - } - .instrument(request_span) - .await; - - let submitted = rx_sub.recv().await.expect("submission"); - assert_eq!(submitted.trace, Some(expected_trace)); - } - - #[tokio::test] - async fn new_default_turn_captures_current_span_trace_id() { - let (session, _turn_context) = make_session_and_context().await; - - init_test_tracing(); - - let request_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), - tracestate: Some("vendor=value".into()), - }; - let request_span = info_span!("app_server.request"); - assert!(set_parent_from_w3c_trace_context( - &request_span, - &request_parent - )); - - let turn_context_item = async { - let expected_trace_id = Span::current() - .context() - .span() - .span_context() - .trace_id() - .to_string(); - let turn_context = session.new_default_turn().await; - let turn_context_item = turn_context.to_turn_context_item(); - assert_eq!(turn_context_item.trace_id, Some(expected_trace_id)); - turn_context_item - } - .instrument(request_span) - .await; - - assert_eq!( - turn_context_item.trace_id.as_deref(), - Some("00000000000000000000000000000011") - ); - } - - #[test] - fn submission_dispatch_span_prefers_submission_trace_context() { - init_test_tracing(); - - let ambient_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000033-0000000000000044-01".into()), - tracestate: None, - }; - let ambient_span = info_span!("ambient"); - assert!(set_parent_from_w3c_trace_context( - &ambient_span, - &ambient_parent - )); - - let submission_trace = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000055-0000000000000066-01".into()), - tracestate: Some("vendor=value".into()), - }; - let dispatch_span = ambient_span.in_scope(|| { - submission_dispatch_span(&Submission { - id: "sub-1".into(), - op: Op::Interrupt, - trace: Some(submission_trace), - }) - }); - - let trace_id = dispatch_span.context().span().span_context().trace_id(); - assert_eq!( - trace_id, - TraceId::from_hex("00000000000000000000000000000055").expect("trace id") - ); - } - - #[test] - fn submission_dispatch_span_uses_debug_for_realtime_audio() { - init_test_tracing(); - - let dispatch_span = submission_dispatch_span(&Submission { - id: "sub-1".into(), - op: Op::RealtimeConversationAudio(ConversationAudioParams { - frame: RealtimeAudioFrame { - data: "ZmFrZQ==".into(), - sample_rate: 16_000, - num_channels: 1, - samples_per_channel: Some(160), - }, - }), - trace: None, - }); - - assert_eq!( - dispatch_span.metadata().expect("span metadata").level(), - &tracing::Level::DEBUG - ); - } - - #[tokio::test] - async fn spawn_task_turn_span_inherits_dispatch_trace_context() { - struct TraceCaptureTask { - captured_trace: Arc>>, - } - - #[async_trait::async_trait] - impl SessionTask for TraceCaptureTask { - fn kind(&self) -> TaskKind { - TaskKind::Regular - } - - fn span_name(&self) -> &'static str { - "session_task.trace_capture" - } - - async fn run( - self: Arc, - _session: Arc, - _ctx: Arc, - _input: Vec, - _cancellation_token: CancellationToken, - ) -> Option { - let mut trace = self - .captured_trace - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *trace = current_span_w3c_trace_context(); - None - } - } - - init_test_tracing(); - - let request_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), - tracestate: Some("vendor=value".into()), - }; - let request_span = tracing::info_span!("app_server.request"); - assert!(set_parent_from_w3c_trace_context( - &request_span, - &request_parent - )); - - let submission_trace = async { - current_span_w3c_trace_context().expect("request span should have trace context") - } - .instrument(request_span) - .await; - - let dispatch_span = submission_dispatch_span(&Submission { - id: "sub-1".into(), - op: Op::Interrupt, - trace: Some(submission_trace.clone()), - }); - let dispatch_span_id = dispatch_span.context().span().span_context().span_id(); - - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let captured_trace = Arc::new(std::sync::Mutex::new(None)); - - async { - sess.spawn_task( - Arc::clone(&tc), - vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }], - TraceCaptureTask { - captured_trace: Arc::clone(&captured_trace), - }, - ) - .await; - } - .instrument(dispatch_span) - .await; - - let evt = tokio::time::timeout(StdDuration::from_secs(2), rx.recv()) - .await - .expect("timeout waiting for turn completion") - .expect("event"); - assert!(matches!(evt.msg, EventMsg::TurnComplete(_))); - - let task_trace = captured_trace - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone() - .expect("turn task should capture the current span trace context"); - let submission_context = - codex_otel::context_from_w3c_trace_context(&submission_trace).expect("submission"); - let task_context = - codex_otel::context_from_w3c_trace_context(&task_trace).expect("task trace"); - - assert_eq!( - task_context.span().span_context().trace_id(), - submission_context.span().span_context().trace_id() - ); - assert_ne!( - task_context.span().span_context().span_id(), - dispatch_span_id - ); - } - - pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( - dynamic_tools: Vec, - ) -> ( - Arc, - Arc, - async_channel::Receiver, - ) { - let (tx_event, rx_event) = async_channel::unbounded(); - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let conversation_id = ThreadId::default(); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( - config.codex_home.clone(), - auth_manager.clone(), - None, - CollaborationModesConfig::default(), - )); - let agent_control = AgentControl::default(); - let exec_policy = ExecPolicyManager::default(); - let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools, - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - let per_turn_config = Session::build_per_turn_config(&session_configuration); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - session_configuration.collaboration_mode.model(), - &per_turn_config, - ); - let otel_manager = otel_manager( - conversation_id, - config.as_ref(), - &model_info, - session_configuration.session_source.clone(), - ); - - let state = SessionState::new(session_configuration.clone()); - let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); - let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let skills_manager = Arc::new(SkillsManager::new( - config.codex_home.clone(), - Arc::clone(&plugins_manager), - )); - let network_approval = Arc::new(NetworkApprovalService::default()); - - let file_watcher = Arc::new(FileWatcher::noop()); - let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new( - McpConnectionManager::new_mcp_connection_manager_for_tests( - &config.permissions.approval_policy, - ), - )), - mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), - unified_exec_manager: UnifiedExecProcessManager::new( - config.background_terminal_max_timeout, - ), - shell_zsh_path: None, - main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), - analytics_events_client: AnalyticsEventsClient::new( - Arc::clone(&config), - Arc::clone(&auth_manager), - ), - hooks: Hooks::new(HooksConfig { - legacy_notify_argv: config.notify.clone(), - }), - rollout: Mutex::new(None), - user_shell: Arc::new(default_user_shell()), - shell_snapshot_tx: watch::channel(None).0, - show_raw_agent_reasoning: config.show_raw_agent_reasoning, - exec_policy, - auth_manager: Arc::clone(&auth_manager), - otel_manager: otel_manager.clone(), - models_manager: Arc::clone(&models_manager), - tool_approvals: Mutex::new(ApprovalStore::default()), - execve_session_approvals: RwLock::new(HashMap::new()), - skills_manager, - plugins_manager, - mcp_manager, - file_watcher, - agent_control, - network_proxy: None, - network_approval: Arc::clone(&network_approval), - state_db: None, - model_client: ModelClient::new( - Some(Arc::clone(&auth_manager)), - conversation_id, - session_configuration.provider.clone(), - session_configuration.session_source.clone(), - config.model_verbosity, - ws_version_from_features(config.as_ref()), - config.features.enabled(Feature::EnableRequestCompression), - config.features.enabled(Feature::RuntimeMetrics), - Session::build_model_client_beta_features_header(config.as_ref()), - ), - }; - let js_repl = Arc::new(JsReplHandle::with_node_path( - config.js_repl_node_path.clone(), - config.js_repl_node_module_dirs.clone(), - )); - - let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); - let turn_context = Arc::new(Session::make_turn_context( - Some(Arc::clone(&auth_manager)), - &otel_manager, - session_configuration.provider.clone(), - &session_configuration, - per_turn_config, - model_info, - None, - "turn_id".to_string(), - Arc::clone(&js_repl), - skills_outcome, - )); - - let session = Arc::new(Session { - conversation_id, - tx_event, - agent_status: agent_status_tx, - state: Mutex::new(state), - features: config.features.clone(), - pending_mcp_server_refresh_config: Mutex::new(None), - conversation: Arc::new(RealtimeConversationManager::new()), - active_turn: Mutex::new(None), - services, - js_repl, - next_internal_sub_id: AtomicU64::new(0), - }); - - (session, turn_context, rx_event) - } - - // Like make_session_and_context, but returns Arc and the event receiver - // so tests can assert on emitted events. - pub(crate) async fn make_session_and_context_with_rx() -> ( - Arc, - Arc, - async_channel::Receiver, - ) { - make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await - } - - #[tokio::test] - async fn refresh_mcp_servers_is_deferred_until_next_turn() { - let (session, turn_context) = make_session_and_context().await; - let old_token = session.mcp_startup_cancellation_token().await; - assert!(!old_token.is_cancelled()); - - let mcp_oauth_credentials_store_mode = - serde_json::to_value(OAuthCredentialsStoreMode::Auto).expect("serialize store mode"); - let refresh_config = McpServerRefreshConfig { - mcp_servers: json!({}), - mcp_oauth_credentials_store_mode, - }; - { - let mut guard = session.pending_mcp_server_refresh_config.lock().await; - *guard = Some(refresh_config); - } - - assert!(!old_token.is_cancelled()); - assert!( - session - .pending_mcp_server_refresh_config - .lock() - .await - .is_some() - ); - - session - .refresh_mcp_servers_if_requested(&turn_context) - .await; - - assert!(old_token.is_cancelled()); - assert!( - session - .pending_mcp_server_refresh_config - .lock() - .await - .is_none() - ); - let new_token = session.mcp_startup_cancellation_token().await; - assert!(!new_token.is_cancelled()); - } - - #[tokio::test] - async fn record_model_warning_appends_user_message() { - let (mut session, turn_context) = make_session_and_context().await; - let features = crate::features::Features::with_defaults().into(); - session.features = features; - - session - .record_model_warning("too many unified exec processes", &turn_context) - .await; - - let history = session.clone_history().await; - let history_items = history.raw_items(); - let last = history_items.last().expect("warning recorded"); - - match last { - ResponseItem::Message { role, content, .. } => { - assert_eq!(role, "user"); - assert_eq!( - content, - &vec![ContentItem::InputText { - text: "Warning: too many unified exec processes".to_string(), - }] - ); - } - other => panic!("expected user message, got {other:?}"), - } - } - - #[tokio::test] - async fn spawn_task_does_not_update_previous_turn_settings_for_non_run_turn_tasks() { - let (sess, tc, _rx) = make_session_and_context_with_rx().await; - sess.set_previous_turn_settings(None).await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: true, - }, - ) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - assert_eq!(sess.previous_turn_settings().await, None); - } - - #[tokio::test] - async fn build_settings_update_items_emits_environment_item_for_network_changes() { - let (session, previous_context) = make_session_and_context().await; - let previous_context = Arc::new(previous_context); - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - - let mut config = (*current_context.config).clone(); - let mut requirements = config.config_layer_stack.requirements().clone(); - requirements.network = Some(Sourced::new( - NetworkConstraints { - allowed_domains: Some(vec!["api.example.com".to_string()]), - denied_domains: Some(vec!["blocked.example.com".to_string()]), - ..Default::default() - }, - RequirementSource::CloudRequirements, - )); - let layers = config - .config_layer_stack - .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) - .into_iter() - .cloned() - .collect(); - config.config_layer_stack = ConfigLayerStack::new( - layers, - requirements, - config.config_layer_stack.requirements_toml().clone(), - ) - .expect("rebuild config layer stack with network requirements"); - current_context.config = Arc::new(config); - - let reference_context_item = previous_context.to_turn_context_item(); - let update_items = session - .build_settings_update_items(Some(&reference_context_item), ¤t_context) - .await; - - let environment_update = update_items - .iter() - .find_map(|item| match item { - ResponseItem::Message { role, content, .. } if role == "user" => { - let [ContentItem::InputText { text }] = content.as_slice() else { - return None; - }; - text.contains("").then_some(text) - } - _ => None, - }) - .expect("environment update item should be emitted"); - assert!(environment_update.contains("")); - assert!(environment_update.contains("api.example.com")); - assert!(environment_update.contains("blocked.example.com")); - } - - #[tokio::test] - async fn build_settings_update_items_emits_environment_item_for_time_changes() { - let (session, previous_context) = make_session_and_context().await; - let previous_context = Arc::new(previous_context); - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.current_date = Some("2026-02-27".to_string()); - current_context.timezone = Some("Europe/Berlin".to_string()); - - let reference_context_item = previous_context.to_turn_context_item(); - let update_items = session - .build_settings_update_items(Some(&reference_context_item), ¤t_context) - .await; - - let environment_update = update_items - .iter() - .find_map(|item| match item { - ResponseItem::Message { role, content, .. } if role == "user" => { - let [ContentItem::InputText { text }] = content.as_slice() else { - return None; - }; - text.contains("").then_some(text) - } - _ => None, - }) - .expect("environment update item should be emitted"); - assert!(environment_update.contains("2026-02-27")); - assert!(environment_update.contains("Europe/Berlin")); - } - - #[tokio::test] - async fn build_settings_update_items_emits_realtime_start_when_session_becomes_live() { - let (session, previous_context) = make_session_and_context().await; - let previous_context = Arc::new(previous_context); - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.realtime_active = true; - - let update_items = session - .build_settings_update_items( - Some(&previous_context.to_turn_context_item()), - ¤t_context, - ) - .await; - - let developer_texts = developer_input_texts(&update_items); - assert!( - developer_texts - .iter() - .any(|text| text.contains("")), - "expected a realtime start update, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_settings_update_items_emits_realtime_end_when_session_stops_being_live() { - let (session, mut previous_context) = make_session_and_context().await; - previous_context.realtime_active = true; - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.realtime_active = false; - - let update_items = session - .build_settings_update_items( - Some(&previous_context.to_turn_context_item()), - ¤t_context, - ) - .await; - - let developer_texts = developer_input_texts(&update_items); - assert!( - developer_texts - .iter() - .any(|text| text.contains("Reason: inactive")), - "expected a realtime end update, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_settings_update_items_uses_previous_turn_settings_for_realtime_end() { - let (session, previous_context) = make_session_and_context().await; - let mut previous_context_item = previous_context.to_turn_context_item(); - previous_context_item.realtime_active = None; - let previous_turn_settings = PreviousTurnSettings { - model: previous_context.model_info.slug.clone(), - realtime_active: Some(true), - }; - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.realtime_active = false; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let update_items = session - .build_settings_update_items(Some(&previous_context_item), ¤t_context) - .await; - - let developer_texts = developer_input_texts(&update_items); - assert!( - developer_texts - .iter() - .any(|text| text.contains("Reason: inactive")), - "expected a realtime end update from previous turn settings, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_initial_context_uses_previous_realtime_state() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.realtime_active = true; - - let initial_context = session.build_initial_context(&turn_context).await; - let developer_texts = developer_input_texts(&initial_context); - assert!( - developer_texts - .iter() - .any(|text| text.contains("")), - "expected initial context to describe active realtime state, got {developer_texts:?}" - ); - - let previous_context_item = turn_context.to_turn_context_item(); - { - let mut state = session.state.lock().await; - state.set_reference_context_item(Some(previous_context_item)); - } - let resumed_context = session.build_initial_context(&turn_context).await; - let resumed_developer_texts = developer_input_texts(&resumed_context); - assert!( - !resumed_developer_texts - .iter() - .any(|text| text.contains("")), - "did not expect a duplicate realtime update, got {resumed_developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_initial_context_uses_previous_turn_settings_for_realtime_end() { - let (session, turn_context) = make_session_and_context().await; - let previous_turn_settings = PreviousTurnSettings { - model: turn_context.model_info.slug.clone(), - realtime_active: Some(true), - }; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - let developer_texts = developer_input_texts(&initial_context); - assert!( - developer_texts - .iter() - .any(|text| text.contains("Reason: inactive")), - "expected initial context to describe an ended realtime session, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_initial_context_restates_realtime_start_when_reference_context_is_missing() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.realtime_active = true; - let previous_turn_settings = PreviousTurnSettings { - model: turn_context.model_info.slug.clone(), - realtime_active: Some(true), - }; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - let developer_texts = developer_input_texts(&initial_context); - assert!( - developer_texts - .iter() - .any(|text| text.contains("")), - "expected initial context to restate active realtime when the reference context is missing, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_injects_full_context_when_baseline_missing() - { - let (session, turn_context) = make_session_and_context().await; - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - let history = session.clone_history().await; - let initial_context = session.build_initial_context(&turn_context).await; - assert_eq!(history.raw_items().to_vec(), initial_context); - - let current_context = session.reference_context_item().await; - assert_eq!( - serde_json::to_value(current_context).expect("serialize current context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected context item") - ); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_reinjects_full_context_after_clear() - { - let (session, turn_context) = make_session_and_context().await; - let compacted_summary = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: format!("{}\nsummary", crate::compact::SUMMARY_PREFIX), - }], - end_turn: None, - phase: None, - }; - session - .record_into_history(std::slice::from_ref(&compacted_summary), &turn_context) - .await; - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - { - let mut state = session.state.lock().await; - state.set_reference_context_item(None); - } - session - .replace_history(vec![compacted_summary.clone()], None) - .await; - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - - let history = session.clone_history().await; - let mut expected_history = vec![compacted_summary]; - expected_history.extend(session.build_initial_context(&turn_context).await); - assert_eq!(history.raw_items().to_vec(), expected_history); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_persists_baseline_without_emitting_diffs() - { - let (session, previous_context) = make_session_and_context().await; - let next_model = if previous_context.model_info.slug == "gpt-5.1" { - "gpt-5" - } else { - "gpt-5.1" - }; - let turn_context = previous_context - .with_model(next_model.to_string(), &session.services.models_manager) - .await; - let previous_context_item = previous_context.to_turn_context_item(); - { - let mut state = session.state.lock().await; - state.set_reference_context_item(Some(previous_context_item.clone())); - } - let config = session.get_config().await; - let recorder = RolloutRecorder::new( - config.as_ref(), - RolloutRecorderParams::new( - ThreadId::default(), - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await - .expect("create rollout recorder"); - let rollout_path = recorder.rollout_path().to_path_buf(); - { - let mut rollout = session.services.rollout.lock().await; - *rollout = Some(recorder); - } - - let update_items = session - .build_settings_update_items(Some(&previous_context_item), &turn_context) - .await; - assert_eq!(update_items, Vec::new()); - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - - assert_eq!( - session.clone_history().await.raw_items().to_vec(), - Vec::new() - ); - assert_eq!( - serde_json::to_value(session.reference_context_item().await) - .expect("serialize current context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected context item") - ); - session.ensure_rollout_materialized().await; - session.flush_rollout().await; - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - let persisted_turn_context = resumed.history.iter().find_map(|item| match item { - RolloutItem::TurnContext(ctx) => Some(ctx.clone()), - _ => None, - }); - assert_eq!( - serde_json::to_value(persisted_turn_context) - .expect("serialize persisted turn context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected turn context item") - ); - } - - #[tokio::test] - async fn build_initial_context_prepends_model_switch_message() { - let (session, turn_context) = make_session_and_context().await; - let previous_turn_settings = PreviousTurnSettings { - model: "previous-regular-model".to_string(), - realtime_active: None, - }; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - - let ResponseItem::Message { role, content, .. } = &initial_context[0] else { - panic!("expected developer message"); - }; - assert_eq!(role, "developer"); - let [ContentItem::InputText { text }, ..] = content.as_slice() else { - panic!("expected developer text"); - }; - assert!(text.contains("")); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_persists_full_reinjection_to_rollout() - { - let (session, previous_context) = make_session_and_context().await; - let next_model = if previous_context.model_info.slug == "gpt-5.1" { - "gpt-5" - } else { - "gpt-5.1" - }; - let turn_context = previous_context - .with_model(next_model.to_string(), &session.services.models_manager) - .await; - let config = session.get_config().await; - let recorder = RolloutRecorder::new( - config.as_ref(), - RolloutRecorderParams::new( - ThreadId::default(), - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await - .expect("create rollout recorder"); - let rollout_path = recorder.rollout_path().to_path_buf(); - { - let mut rollout = session.services.rollout.lock().await; - *rollout = Some(recorder); - } - - session - .persist_rollout_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( - UserMessageEvent { - message: "seed rollout".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - ))]) - .await; - { - let mut state = session.state.lock().await; - state.set_reference_context_item(None); - } - - session - .set_previous_turn_settings(Some(PreviousTurnSettings { - model: previous_context.model_info.slug.clone(), - realtime_active: Some(previous_context.realtime_active), - })) - .await; - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - session.ensure_rollout_materialized().await; - session.flush_rollout().await; - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - let persisted_turn_context = resumed.history.iter().find_map(|item| match item { - RolloutItem::TurnContext(ctx) => Some(ctx.clone()), - _ => None, - }); - - assert_eq!( - serde_json::to_value(persisted_turn_context) - .expect("serialize persisted turn context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected turn context item") - ); - } - - #[tokio::test] - async fn run_user_shell_command_does_not_set_reference_context_item() { - let (session, _turn_context, rx) = make_session_and_context_with_rx().await; - { - let mut state = session.state.lock().await; - state.set_reference_context_item(None); - } - - handlers::run_user_shell_command(&session, "sub-id".to_string(), "echo shell".to_string()) - .await; - - let deadline = StdDuration::from_secs(15); - let start = std::time::Instant::now(); - loop { - let remaining = deadline.saturating_sub(start.elapsed()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - if matches!(evt.msg, EventMsg::TurnComplete(_)) { - break; - } - } - - assert!( - session.reference_context_item().await.is_none(), - "standalone shell tasks should not mutate previous context" - ); - } - - #[derive(Clone, Copy)] - struct NeverEndingTask { - kind: TaskKind, - listen_to_cancellation_token: bool, - } - - #[async_trait::async_trait] - impl SessionTask for NeverEndingTask { - fn kind(&self) -> TaskKind { - self.kind - } - - fn span_name(&self) -> &'static str { - "session_task.never_ending" - } - - async fn run( - self: Arc, - _session: Arc, - _ctx: Arc, - _input: Vec, - cancellation_token: CancellationToken, - ) -> Option { - if self.listen_to_cancellation_token { - cancellation_token.cancelled().await; - return None; - } - loop { - sleep(Duration::from_secs(60)).await; - } - } - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - #[test_log::test] - async fn abort_regular_task_emits_turn_aborted_only() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - - // Interrupts persist a model-visible `` marker into history, but there is no - // separate client-visible event for that marker (only `EventMsg::TurnAborted`). - let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), - other => panic!("unexpected event: {other:?}"), - } - // No extra events should be emitted after an abort. - assert!(rx.try_recv().is_err()); - } - - #[tokio::test] - async fn abort_gracefully_emits_turn_aborted_only() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: true, - }, - ) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - - // Even if tasks handle cancellation gracefully, interrupts still result in `TurnAborted` - // being the only client-visible signal. - let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), - other => panic!("unexpected event: {other:?}"), - } - // No extra events should be emitted after an abort. - assert!(rx.try_recv().is_err()); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - while rx.try_recv().is_ok() {} - - sess.inject_response_items(vec![ResponseInputItem::Message { - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "late pending input".to_string(), - }], - }]) - .await - .expect("inject pending input into active turn"); - - sess.on_task_finished(Arc::clone(&tc), None).await; - - let history = sess.clone_history().await; - let expected = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "late pending input".to_string(), - }], - end_turn: None, - phase: None, - }; - assert!( - history.raw_items().iter().any(|item| item == &expected), - "expected pending input to be persisted into history on turn completion" - ); - - let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected raw response item event") - .expect("channel open"); - assert!(matches!(first.msg, EventMsg::RawResponseItem(_))); - - let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected item started event") - .expect("channel open"); - assert!(matches!( - second.msg, - EventMsg::ItemStarted(ItemStartedEvent { - item: TurnItem::UserMessage(UserMessageItem { content, .. }), - .. - }) if content == vec![UserInput::Text { - text: "late pending input".to_string(), - text_elements: Vec::new(), - }] - )); - - let third = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected item completed event") - .expect("channel open"); - assert!(matches!( - third.msg, - EventMsg::ItemCompleted(ItemCompletedEvent { - item: TurnItem::UserMessage(UserMessageItem { content, .. }), - .. - }) if content == vec![UserInput::Text { - text: "late pending input".to_string(), - text_elements: Vec::new(), - }] - )); - - let fourth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected legacy user message event") - .expect("channel open"); - assert!(matches!( - fourth.msg, - EventMsg::UserMessage(UserMessageEvent { - message, - images, - text_elements, - local_images, - }) if message == "late pending input" - && images == Some(Vec::new()) - && text_elements.is_empty() - && local_images.is_empty() - )); - - let fifth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected turn complete event") - .expect("channel open"); - assert!(matches!( - fifth.msg, - EventMsg::TurnComplete(TurnCompleteEvent { - turn_id, - last_agent_message: None, - }) if turn_id == tc.sub_id - )); - } - - #[tokio::test] - async fn steer_input_requires_active_turn() { - let (sess, _tc, _rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "steer".to_string(), - text_elements: Vec::new(), - }]; - - let err = sess - .steer_input(input, None) - .await - .expect_err("steering without active turn should fail"); - - assert!(matches!(err, SteerInputError::NoActiveTurn(_))); - } - - #[tokio::test] - async fn steer_input_enforces_expected_turn_id() { - let (sess, tc, _rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - let steer_input = vec![UserInput::Text { - text: "steer".to_string(), - text_elements: Vec::new(), - }]; - let err = sess - .steer_input(steer_input, Some("different-turn-id")) - .await - .expect_err("mismatched expected turn id should fail"); - - match err { - SteerInputError::ExpectedTurnMismatch { expected, actual } => { - assert_eq!( - (expected, actual), - ("different-turn-id".to_string(), tc.sub_id.clone()) - ); - } - other => panic!("unexpected error: {other:?}"), - } - } - - #[tokio::test] - async fn steer_input_returns_active_turn_id() { - let (sess, tc, _rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - let steer_input = vec![UserInput::Text { - text: "steer".to_string(), - text_elements: Vec::new(), - }]; - let turn_id = sess - .steer_input(steer_input, Some(&tc.sub_id)) - .await - .expect("steering with matching expected turn id should succeed"); - - assert_eq!(turn_id, tc.sub_id); - assert!(sess.has_pending_input().await); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn abort_review_task_emits_exited_then_aborted_and_records_history() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "start review".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task(Arc::clone(&tc), input, ReviewTask::new()) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - - // Aborting a review task should exit review mode before surfacing the abort to the client. - // We scan for these events (rather than relying on fixed ordering) since unrelated events - // may interleave. - let mut exited_review_mode_idx = None; - let mut turn_aborted_idx = None; - let mut idx = 0usize; - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(3); - while tokio::time::Instant::now() < deadline { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - let event_idx = idx; - idx = idx.saturating_add(1); - match evt.msg { - EventMsg::ExitedReviewMode(ev) => { - assert!(ev.review_output.is_none()); - exited_review_mode_idx = Some(event_idx); - } - EventMsg::TurnAborted(ev) => { - assert_eq!(TurnAbortReason::Interrupted, ev.reason); - turn_aborted_idx = Some(event_idx); - break; - } - _ => {} - } - } - assert!( - exited_review_mode_idx.is_some(), - "expected ExitedReviewMode after abort" - ); - assert!( - turn_aborted_idx.is_some(), - "expected TurnAborted after abort" - ); - assert!( - exited_review_mode_idx.unwrap() < turn_aborted_idx.unwrap(), - "expected ExitedReviewMode before TurnAborted" - ); - - let history = sess.clone_history().await; - // The `` marker is silent in the event stream, so verify it is still - // recorded in history for the model. - assert!( - history.raw_items().iter().any(|item| { - let ResponseItem::Message { role, content, .. } = item else { - return false; - }; - if role != "user" { - return false; - } - content.iter().any(|content_item| { - let ContentItem::InputText { text } = content_item else { - return false; - }; - text.contains(crate::contextual_user_message::TURN_ABORTED_OPEN_TAG) - }) - }), - "expected a model-visible turn aborted marker in history after interrupt" - ); - } - - #[tokio::test] - async fn fatal_tool_error_stops_turn_and_reports_error() { - let (session, turn_context, _rx) = make_session_and_context_with_rx().await; - let tools = { - session - .services - .mcp_connection_manager - .read() - .await - .list_all_tools() - .await - }; - let app_tools = Some(tools.clone()); - let router = ToolRouter::from_config( - &turn_context.tools_config, - Some( - tools - .into_iter() - .map(|(name, tool)| (name, tool.tool)) - .collect(), - ), - app_tools, - turn_context.dynamic_tools.as_slice(), - ); - let item = ResponseItem::CustomToolCall { - id: None, - status: None, - call_id: "call-1".to_string(), - name: "shell".to_string(), - input: "{}".to_string(), - }; - - let call = ToolRouter::build_tool_call(session.as_ref(), item.clone()) - .await - .expect("build tool call") - .expect("tool call present"); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let err = router - .dispatch_tool_call( - Arc::clone(&session), - Arc::clone(&turn_context), - tracker, - call, - ToolCallSource::Direct, - ) - .await - .expect_err("expected fatal error"); - - match err { - FunctionCallError::Fatal(message) => { - assert_eq!(message, "tool shell invoked with incompatible payload"); - } - other => panic!("expected FunctionCallError::Fatal, got {other:?}"), - } - } - - async fn sample_rollout( - session: &Session, - _turn_context: &TurnContext, - ) -> (Vec, Vec) { - let mut rollout_items = Vec::new(); - let mut live_history = ContextManager::new(); - - // Use the same turn_context source as record_initial_history so model_info (and thus - // personality_spec) matches reconstruction. - let reconstruction_turn = session.new_default_turn().await; - let mut initial_context = session - .build_initial_context(reconstruction_turn.as_ref()) - .await; - // Ensure personality_spec is present when Personality is enabled, so expected matches - // what reconstruction produces (build_initial_context may omit it when baked into model). - if !initial_context.iter().any(|m| { - matches!(m, ResponseItem::Message { role, content, .. } - if role == "developer" - && content.iter().any(|c| { - matches!(c, ContentItem::InputText { text } if text.contains("")) - })) - }) - && let Some(p) = reconstruction_turn.personality - && session.features.enabled(Feature::Personality) - && let Some(personality_message) = reconstruction_turn - .model_info - .model_messages - .as_ref() - .and_then(|m| m.get_personality_message(Some(p)).filter(|s| !s.is_empty())) - { - let msg = - DeveloperInstructions::personality_spec_message(personality_message).into(); - let insert_at = initial_context - .iter() - .position(|m| matches!(m, ResponseItem::Message { role, .. } if role == "developer")) - .map(|i| i + 1) - .unwrap_or(0); - initial_context.insert(insert_at, msg); - } - for item in &initial_context { - rollout_items.push(RolloutItem::ResponseItem(item.clone())); - } - live_history.record_items( - initial_context.iter(), - reconstruction_turn.truncation_policy, - ); - - let user1 = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "first user".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&user1), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(user1)); - - let assistant1 = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "assistant reply one".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&assistant1), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(assistant1)); - - let summary1 = "summary one"; - let snapshot1 = live_history - .clone() - .for_prompt(&reconstruction_turn.model_info.input_modalities); - let user_messages1 = collect_user_messages(&snapshot1); - let rebuilt1 = compact::build_compacted_history(Vec::new(), &user_messages1, summary1); - live_history.replace(rebuilt1); - rollout_items.push(RolloutItem::Compacted(CompactedItem { - message: summary1.to_string(), - replacement_history: None, - })); - - let user2 = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "second user".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&user2), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(user2)); - - let assistant2 = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "assistant reply two".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&assistant2), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(assistant2)); - - let summary2 = "summary two"; - let snapshot2 = live_history - .clone() - .for_prompt(&reconstruction_turn.model_info.input_modalities); - let user_messages2 = collect_user_messages(&snapshot2); - let rebuilt2 = compact::build_compacted_history(Vec::new(), &user_messages2, summary2); - live_history.replace(rebuilt2); - rollout_items.push(RolloutItem::Compacted(CompactedItem { - message: summary2.to_string(), - replacement_history: None, - })); - - let user3 = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "third user".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&user3), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(user3)); - - let assistant3 = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "assistant reply three".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&assistant3), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(assistant3)); - - ( - rollout_items, - live_history.for_prompt(&reconstruction_turn.model_info.input_modalities), - ) - } - - #[tokio::test] - async fn rejects_escalated_permissions_when_policy_not_on_request() { - use crate::exec::ExecParams; - use crate::protocol::AskForApproval; - use crate::protocol::SandboxPolicy; - use crate::sandboxing::SandboxPermissions; - use crate::turn_diff_tracker::TurnDiffTracker; - use std::collections::HashMap; - - let (session, mut turn_context_raw) = make_session_and_context().await; - // Ensure policy is NOT OnRequest so the early rejection path triggers - turn_context_raw - .approval_policy - .set(AskForApproval::OnFailure) - .expect("test setup should allow updating approval policy"); - let session = Arc::new(session); - let mut turn_context = Arc::new(turn_context_raw); - - let timeout_ms = 1000; - let sandbox_permissions = SandboxPermissions::RequireEscalated; - let params = ExecParams { - command: if cfg!(windows) { - vec![ - "cmd.exe".to_string(), - "/C".to_string(), - "echo hi".to_string(), - ] - } else { - vec![ - "/bin/sh".to_string(), - "-c".to_string(), - "echo hi".to_string(), - ] - }, - cwd: turn_context.cwd.clone(), - expiration: timeout_ms.into(), - env: HashMap::new(), - network: None, - sandbox_permissions, - windows_sandbox_level: turn_context.windows_sandbox_level, - justification: Some("test".to_string()), - arg0: None, - }; - - let params2 = ExecParams { - sandbox_permissions: SandboxPermissions::UseDefault, - command: params.command.clone(), - cwd: params.cwd.clone(), - expiration: timeout_ms.into(), - env: HashMap::new(), - network: None, - windows_sandbox_level: turn_context.windows_sandbox_level, - justification: params.justification.clone(), - arg0: None, - }; - - let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - - let tool_name = "shell"; - let call_id = "test-call".to_string(); - - let handler = ShellHandler; - let resp = handler - .handle(ToolInvocation { - session: Arc::clone(&session), - turn: Arc::clone(&turn_context), - tracker: Arc::clone(&turn_diff_tracker), - call_id, - tool_name: tool_name.to_string(), - payload: ToolPayload::Function { - arguments: serde_json::json!({ - "command": params.command.clone(), - "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), - "timeout_ms": params.expiration.timeout_ms(), - "sandbox_permissions": params.sandbox_permissions, - "justification": params.justification.clone(), - }) - .to_string(), - }, - }) - .await; - - let Err(FunctionCallError::RespondToModel(output)) = resp else { - panic!("expected error result"); - }; - - let expected = format!( - "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", - policy = turn_context.approval_policy.value() - ); - - pretty_assertions::assert_eq!(output, expected); - - // Now retry the same command WITHOUT escalated permissions; should succeed. - // Force DangerFullAccess to avoid platform sandbox dependencies in tests. - Arc::get_mut(&mut turn_context) - .expect("unique turn context Arc") - .sandbox_policy - .set(SandboxPolicy::DangerFullAccess) - .expect("test setup should allow updating sandbox policy"); - - let resp2 = handler - .handle(ToolInvocation { - session: Arc::clone(&session), - turn: Arc::clone(&turn_context), - tracker: Arc::clone(&turn_diff_tracker), - call_id: "test-call-2".to_string(), - tool_name: tool_name.to_string(), - payload: ToolPayload::Function { - arguments: serde_json::json!({ - "command": params2.command.clone(), - "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), - "timeout_ms": params2.expiration.timeout_ms(), - "sandbox_permissions": params2.sandbox_permissions, - "justification": params2.justification.clone(), - }) - .to_string(), - }, - }) - .await; - - let output = match resp2.expect("expected Ok result") { - ToolOutput::Function { - body: FunctionCallOutputBody::Text(content), - .. - } => content, - _ => panic!("unexpected tool output"), - }; - - #[derive(Deserialize, PartialEq, Eq, Debug)] - struct ResponseExecMetadata { - exit_code: i32, - } - - #[derive(Deserialize)] - struct ResponseExecOutput { - output: String, - metadata: ResponseExecMetadata, - } - - let exec_output: ResponseExecOutput = - serde_json::from_str(&output).expect("valid exec output json"); - - pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 }); - assert!(exec_output.output.contains("hi")); - } - #[tokio::test] - async fn unified_exec_rejects_escalated_permissions_when_policy_not_on_request() { - use crate::protocol::AskForApproval; - use crate::sandboxing::SandboxPermissions; - use crate::turn_diff_tracker::TurnDiffTracker; - - let (session, mut turn_context_raw) = make_session_and_context().await; - turn_context_raw - .approval_policy - .set(AskForApproval::OnFailure) - .expect("test setup should allow updating approval policy"); - let session = Arc::new(session); - let turn_context = Arc::new(turn_context_raw); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - - let handler = UnifiedExecHandler; - let resp = handler - .handle(ToolInvocation { - session: Arc::clone(&session), - turn: Arc::clone(&turn_context), - tracker: Arc::clone(&tracker), - call_id: "exec-call".to_string(), - tool_name: "exec_command".to_string(), - payload: ToolPayload::Function { - arguments: serde_json::json!({ - "cmd": "echo hi", - "sandbox_permissions": SandboxPermissions::RequireEscalated, - "justification": "need unsandboxed execution", - }) - .to_string(), - }, - }) - .await; - - let Err(FunctionCallError::RespondToModel(output)) = resp else { - panic!("expected error result"); - }; - - let expected = format!( - "approval policy is {policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {policy:?}", - policy = turn_context.approval_policy.value() - ); - - pretty_assertions::assert_eq!(output, expected); - } From a41777e8f8a2a3e47caed26a7e6234f09bfa7d24 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 6 Mar 2026 16:15:19 -0500 Subject: [PATCH 6/7] only v2, not v1 --- codex-rs/app-server-protocol/src/protocol/v1.rs | 4 +--- codex-rs/core/src/config/mod.rs | 2 +- codex-rs/core/src/config/profile.rs | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/codex-rs/app-server-protocol/src/protocol/v1.rs b/codex-rs/app-server-protocol/src/protocol/v1.rs index c00ec2d5b1c..d393f97f72b 100644 --- a/codex-rs/app-server-protocol/src/protocol/v1.rs +++ b/codex-rs/app-server-protocol/src/protocol/v1.rs @@ -7,7 +7,6 @@ use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::Verbosity; -use codex_protocol::config_types::WebSearchToolConfig; use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ReasoningEffort; use codex_protocol::parse_command::ParsedCommand; @@ -386,13 +385,12 @@ pub struct Profile { pub model_reasoning_summary: Option, pub model_verbosity: Option, pub chatgpt_base_url: Option, - pub tools: Option, } #[derive(Deserialize, Debug, Clone, PartialEq, Serialize, JsonSchema, TS)] #[serde(rename_all = "camelCase")] pub struct Tools { - pub web_search: Option, + pub web_search: Option, pub view_image: Option, } diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 9294a9bd3d7..edd5e6720d9 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -1423,7 +1423,7 @@ pub struct AgentRoleToml { impl From for Tools { fn from(tools_toml: ToolsToml) -> Self { Self { - web_search: tools_toml.web_search, + web_search: tools_toml.web_search.is_some().then_some(true), view_image: tools_toml.view_image, } } diff --git a/codex-rs/core/src/config/profile.rs b/codex-rs/core/src/config/profile.rs index 3ee213513d5..ce454ff0a85 100644 --- a/codex-rs/core/src/config/profile.rs +++ b/codex-rs/core/src/config/profile.rs @@ -71,7 +71,6 @@ impl From for codex_app_server_protocol::Profile { model_reasoning_effort: config_profile.model_reasoning_effort, model_reasoning_summary: config_profile.model_reasoning_summary, model_verbosity: config_profile.model_verbosity, - tools: config_profile.tools.map(Into::into), chatgpt_base_url: config_profile.chatgpt_base_url, } } From 3ae497432892033bb2ae16c0c8be6649965c550e Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 6 Mar 2026 19:10:59 -0500 Subject: [PATCH 7/7] make merging less verbose --- codex-rs/core/src/config/mod.rs | 35 +---------- codex-rs/protocol/src/config_types.rs | 91 +++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 34 deletions(-) diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index edd5e6720d9..e808632fa3e 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -68,7 +68,6 @@ use codex_protocol::config_types::ServiceTier; use codex_protocol::config_types::TrustLevel; use codex_protocol::config_types::Verbosity; use codex_protocol::config_types::WebSearchConfig; -use codex_protocol::config_types::WebSearchLocation; use codex_protocol::config_types::WebSearchMode; use codex_protocol::config_types::WebSearchToolConfig; use codex_protocol::config_types::WindowsSandboxLevel; @@ -1660,39 +1659,7 @@ fn resolve_web_search_config( (None, None) => None, (Some(base), None) => Some(base.clone().into()), (None, Some(profile)) => Some(profile.clone().into()), - (Some(base), Some(profile)) => Some( - WebSearchToolConfig { - context_size: profile.context_size.or(base.context_size), - allowed_domains: profile - .allowed_domains - .clone() - .or_else(|| base.allowed_domains.clone()), - location: match (base.location.as_ref(), profile.location.as_ref()) { - (None, None) => None, - (Some(base_location), None) => Some(base_location.clone()), - (None, Some(profile_location)) => Some(profile_location.clone()), - (Some(base_location), Some(profile_location)) => Some(WebSearchLocation { - country: profile_location - .country - .clone() - .or_else(|| base_location.country.clone()), - region: profile_location - .region - .clone() - .or_else(|| base_location.region.clone()), - city: profile_location - .city - .clone() - .or_else(|| base_location.city.clone()), - timezone: profile_location - .timezone - .clone() - .or_else(|| base_location.timezone.clone()), - }), - }, - } - .into(), - ), + (Some(base), Some(profile)) => Some(base.merge(profile).into()), } } diff --git a/codex-rs/protocol/src/config_types.rs b/codex-rs/protocol/src/config_types.rs index cb4f934d5e2..4261bb8d1d8 100644 --- a/codex-rs/protocol/src/config_types.rs +++ b/codex-rs/protocol/src/config_types.rs @@ -131,6 +131,17 @@ pub struct WebSearchLocation { pub timezone: Option, } +impl WebSearchLocation { + pub fn merge(&self, other: &Self) -> Self { + Self { + country: other.country.clone().or_else(|| self.country.clone()), + region: other.region.clone().or_else(|| self.region.clone()), + city: other.city.clone().or_else(|| self.city.clone()), + timezone: other.timezone.clone().or_else(|| self.timezone.clone()), + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] #[schemars(deny_unknown_fields)] pub struct WebSearchToolConfig { @@ -139,6 +150,24 @@ pub struct WebSearchToolConfig { pub location: Option, } +impl WebSearchToolConfig { + pub fn merge(&self, other: &Self) -> Self { + Self { + context_size: other.context_size.or(self.context_size), + allowed_domains: other + .allowed_domains + .clone() + .or_else(|| self.allowed_domains.clone()), + location: match (&self.location, &other.location) { + (Some(location), Some(other_location)) => Some(location.merge(other_location)), + (Some(location), None) => Some(location.clone()), + (None, Some(other_location)) => Some(other_location.clone()), + (None, None) => None, + }, + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq, JsonSchema, TS)] #[schemars(deny_unknown_fields)] pub struct WebSearchFilters { @@ -453,4 +482,66 @@ mod tests { assert!(!ModeKind::PairProgramming.is_tui_visible()); assert!(!ModeKind::Execute.is_tui_visible()); } + + #[test] + fn web_search_location_merge_prefers_overlay_values() { + let base = WebSearchLocation { + country: Some("US".to_string()), + region: Some("CA".to_string()), + city: None, + timezone: Some("America/Los_Angeles".to_string()), + }; + let overlay = WebSearchLocation { + country: None, + region: Some("WA".to_string()), + city: Some("Seattle".to_string()), + timezone: None, + }; + + let expected = WebSearchLocation { + country: Some("US".to_string()), + region: Some("WA".to_string()), + city: Some("Seattle".to_string()), + timezone: Some("America/Los_Angeles".to_string()), + }; + + assert_eq!(expected, base.merge(&overlay)); + } + + #[test] + fn web_search_tool_config_merge_prefers_overlay_values() { + let base = WebSearchToolConfig { + context_size: Some(WebSearchContextSize::Low), + allowed_domains: Some(vec!["openai.com".to_string()]), + location: Some(WebSearchLocation { + country: Some("US".to_string()), + region: Some("CA".to_string()), + city: None, + timezone: Some("America/Los_Angeles".to_string()), + }), + }; + let overlay = WebSearchToolConfig { + context_size: Some(WebSearchContextSize::High), + allowed_domains: None, + location: Some(WebSearchLocation { + country: None, + region: Some("WA".to_string()), + city: Some("Seattle".to_string()), + timezone: None, + }), + }; + + let expected = WebSearchToolConfig { + context_size: Some(WebSearchContextSize::High), + allowed_domains: Some(vec!["openai.com".to_string()]), + location: Some(WebSearchLocation { + country: Some("US".to_string()), + region: Some("WA".to_string()), + city: Some("Seattle".to_string()), + timezone: Some("America/Los_Angeles".to_string()), + }), + }; + + assert_eq!(expected, base.merge(&overlay)); + } }