Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions codex-rs/app-server-protocol/schema/json/EventMsg.json
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down Expand Up @@ -6069,6 +6075,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down Expand Up @@ -7214,6 +7226,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2652,6 +2652,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down Expand Up @@ -7445,6 +7451,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4189,6 +4189,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down Expand Up @@ -13847,6 +13853,12 @@
"null"
]
},
"saved_path": {
"type": [
"string",
"null"
]
},
"status": {
"type": "string"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.

export type ImageGenerationEndEvent = { call_id: string, status: string, revised_prompt?: string, result: string, };
export type ImageGenerationEndEvent = { call_id: string, status: string, revised_prompt?: string, result: string, saved_path?: string, };
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.

export type ImageGenerationItem = { id: string, status: string, revised_prompt?: string, result: string, };
export type ImageGenerationItem = { id: string, status: string, revised_prompt?: string, result: string, saved_path?: string, };
4 changes: 2 additions & 2 deletions codex-rs/core/src/client_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub(crate) mod tools {
#[serde(rename = "local_shell")]
LocalShell {},
#[serde(rename = "image_generation")]
ImageGeneration {},
ImageGeneration { output_format: String },
// TODO: Understand why we get an error on web_search although the API docs say it's supported.
// https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
// The `external_web_access` field determines whether the web search is over cached or live content.
Expand All @@ -186,7 +186,7 @@ pub(crate) mod tools {
match self {
ToolSpec::Function(tool) => tool.name.as_str(),
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::ImageGeneration {} => "image_generation",
ToolSpec::ImageGeneration { .. } => "image_generation",
ToolSpec::WebSearch { .. } => "web_search",
ToolSpec::Freeform(tool) => tool.name.as_str(),
}
Expand Down
8 changes: 6 additions & 2 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6217,7 +6217,9 @@ async fn handle_assistant_item_done_in_plan_mode(
{
maybe_complete_plan_item_from_message(sess, turn_context, state, item).await;

if let Some(turn_item) = handle_non_tool_response_item(item, true) {
if let Some(turn_item) =
handle_non_tool_response_item(item, true, Some(&turn_context.cwd)).await
{
emit_turn_item_in_plan_mode(
sess,
turn_context,
Expand Down Expand Up @@ -6396,7 +6398,9 @@ async fn try_run_sampling_request(
needs_follow_up |= output_result.needs_follow_up;
}
ResponseEvent::OutputItemAdded(item) => {
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) {
if let Some(turn_item) =
handle_non_tool_response_item(&item, plan_mode, Some(&turn_context.cwd)).await
{
let mut turn_item = turn_item;
let mut seeded_parsed: Option<ParsedAssistantTextDelta> = None;
let mut seeded_item_id: Option<String> = None;
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/event_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ pub fn parse_turn_item(item: &ResponseItem) -> Option<TurnItem> {
status: status.clone(),
revised_prompt: revised_prompt.clone(),
result: result.clone(),
saved_path: None,
},
)),
_ => None,
Expand Down
149 changes: 144 additions & 5 deletions codex-rs/core/src/stream_events_utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;

use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_protocol::config_types::ModeKind;
use codex_protocol::items::TurnItem;
use codex_utils_stream_parser::strip_citations;
Expand Down Expand Up @@ -50,6 +54,34 @@ pub(crate) fn raw_assistant_output_text_from_item(item: &ResponseItem) -> Option
None
}

/// Decode a base64 image-generation payload and write it as a `.png` file inside `cwd`.
///
/// * `cwd` - directory the file is written into.
/// * `call_id` - identifier used to derive the file name; every character other than an
///   ASCII letter, ASCII digit, `-` or `_` is replaced with `_`, so the resulting name
///   cannot contain path separators. An empty id falls back to `generated_image`.
/// * `result` - payload encoded with the standard base64 alphabet; surrounding whitespace
///   is trimmed before decoding.
///
/// Returns the path of the written file.
///
/// # Errors
///
/// Returns `CodexErr::InvalidRequest` when `result` cannot be decoded, and propagates any
/// error from writing the file.
async fn save_image_generation_result_to_cwd(
    cwd: &Path,
    call_id: &str,
    result: &str,
) -> Result<PathBuf> {
    let payload = result.trim();
    let image_bytes = match BASE64_STANDARD.decode(payload.as_bytes()) {
        Ok(decoded) => decoded,
        Err(err) => {
            return Err(CodexErr::InvalidRequest(format!(
                "invalid image generation payload: {err}"
            )));
        }
    };

    // Keep only characters that are safe in a file name; everything else becomes `_`.
    let sanitized: String = call_id
        .chars()
        .map(|ch| match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => ch,
            _ => '_',
        })
        .collect();
    let file_stem = if sanitized.is_empty() {
        "generated_image".to_string()
    } else {
        sanitized
    };

    let target = cwd.join(format!("{file_stem}.png"));
    tokio::fs::write(&target, image_bytes).await?;
    Ok(target)
}

/// Persist a completed model response item and record any cited memory usage.
pub(crate) async fn record_completed_response_item(
sess: &Session,
Expand Down Expand Up @@ -157,13 +189,16 @@ pub(crate) async fn handle_output_item_done(
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) {
if let Some(turn_item) =
handle_non_tool_response_item(&item, plan_mode, Some(&ctx.turn_context.cwd)).await
{
if previously_active_item.is_none() {
let mut started_item = turn_item.clone();
if let TurnItem::ImageGeneration(item) = &mut started_item {
item.status = "in_progress".to_string();
item.revised_prompt = None;
item.result.clear();
item.saved_path = None;
}
ctx.sess
.emit_turn_item_started(&ctx.turn_context, &started_item)
Expand Down Expand Up @@ -240,9 +275,10 @@ pub(crate) async fn handle_output_item_done(
Ok(output)
}

pub(crate) fn handle_non_tool_response_item(
pub(crate) async fn handle_non_tool_response_item(
item: &ResponseItem,
plan_mode: bool,
image_output_cwd: Option<&Path>,
) -> Option<TurnItem> {
debug!(?item, "Output item");

Expand All @@ -264,6 +300,24 @@ pub(crate) fn handle_non_tool_response_item(
agent_message.content =
vec![codex_protocol::items::AgentMessageContent::Text { text: stripped }];
}
if let TurnItem::ImageGeneration(image_item) = &mut turn_item
&& let Some(cwd) = image_output_cwd
{
match save_image_generation_result_to_cwd(cwd, &image_item.id, &image_item.result)
.await
{
Ok(path) => {
image_item.saved_path = Some(path.to_string_lossy().into_owned());
}
Err(err) => {
tracing::warn!(
call_id = %image_item.id,
cwd = %cwd.display(),
"failed to save generated image: {err}"
);
}
}
}
Some(turn_item)
}
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
Expand Down Expand Up @@ -326,10 +380,13 @@ pub(crate) fn response_input_to_response_item(input: &ResponseInputItem) -> Opti
mod tests {
use super::handle_non_tool_response_item;
use super::last_assistant_message_from_item;
use super::save_image_generation_result_to_cwd;
use crate::error::CodexErr;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use tempfile::tempdir;

fn assistant_output_text(text: &str) -> ResponseItem {
ResponseItem::Message {
Expand All @@ -343,12 +400,14 @@ mod tests {
}
}

#[test]
fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
#[tokio::test]
async fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
let item = assistant_output_text("hello<oai-mem-citation>doc1</oai-mem-citation> world");

let turn_item =
handle_non_tool_response_item(&item, false).expect("assistant message should parse");
handle_non_tool_response_item(&item, false, Some(std::path::Path::new(".")))
.await
.expect("assistant message should parse");

let TurnItem::AgentMessage(agent_message) = turn_item else {
panic!("expected agent message");
Expand Down Expand Up @@ -388,4 +447,84 @@ mod tests {

assert_eq!(last_assistant_message_from_item(&item, true), None);
}

#[tokio::test]
async fn save_image_generation_result_saves_base64_to_png_in_cwd() {
    // "Zm9v" is the standard base64 encoding of the bytes "foo".
    let workdir = tempdir().expect("tempdir");

    let written = save_image_generation_result_to_cwd(workdir.path(), "ig_123", "Zm9v")
        .await
        .expect("image should be saved");

    // The file is named after the call id with a `.png` extension...
    let file_name = written.file_name().and_then(|name| name.to_str());
    assert_eq!(file_name, Some("ig_123.png"));

    // ...and holds the decoded payload.
    let contents = std::fs::read(written).expect("saved file");
    assert_eq!(contents, b"foo");
}

#[tokio::test]
async fn save_image_generation_result_rejects_data_url_payload() {
    // A base64 payload wrapped in a `data:` URL must be rejected as an invalid request.
    let workdir = tempdir().expect("tempdir");

    let outcome = save_image_generation_result_to_cwd(
        workdir.path(),
        "ig_456",
        "data:image/jpeg;base64,Zm9v",
    )
    .await;

    let err = outcome.expect_err("data url payload should error");
    assert!(matches!(err, CodexErr::InvalidRequest(_)));
}

#[tokio::test]
async fn save_image_generation_result_overwrites_existing_file() {
    let workdir = tempdir().expect("tempdir");
    // Seed a file at the exact path the helper is expected to target.
    std::fs::write(workdir.path().join("ig_123.png"), b"existing").expect("seed existing image");

    let written = save_image_generation_result_to_cwd(workdir.path(), "ig_123", "Zm9v")
        .await
        .expect("image should be saved");

    let file_name = written.file_name().and_then(|name| name.to_str());
    assert_eq!(file_name, Some("ig_123.png"));

    // The seeded contents must have been replaced by the decoded payload.
    let contents = std::fs::read(written).expect("saved file");
    assert_eq!(contents, b"foo");
}

#[tokio::test]
async fn save_image_generation_result_sanitizes_call_id_for_output_path() {
    let workdir = tempdir().expect("tempdir");

    // A call id made of traversal segments must not escape the target directory.
    let written = save_image_generation_result_to_cwd(workdir.path(), "../ig/..", "Zm9v")
        .await
        .expect("image should be saved");

    assert_eq!(written.parent(), Some(workdir.path()));

    // Each `.` and `/` in the id is replaced with `_`.
    let file_name = written.file_name().and_then(|name| name.to_str());
    assert_eq!(file_name, Some("___ig___.png"));

    let contents = std::fs::read(written).expect("saved file");
    assert_eq!(contents, b"foo");
}

#[tokio::test]
async fn save_image_generation_result_rejects_non_standard_base64() {
    // "_-8" uses `_` and `-`, which are not part of the standard base64 alphabet.
    let workdir = tempdir().expect("tempdir");

    let outcome = save_image_generation_result_to_cwd(workdir.path(), "ig_urlsafe", "_-8").await;

    let err = outcome.expect_err("non-standard base64 should error");
    assert!(matches!(err, CodexErr::InvalidRequest(_)));
}

#[tokio::test]
async fn save_image_generation_result_rejects_non_base64_data_urls() {
    // A `data:` URL carrying raw (non-base64) content must be rejected as an invalid request.
    let workdir = tempdir().expect("tempdir");
    let payload = "data:image/svg+xml,<svg/>";

    let outcome = save_image_generation_result_to_cwd(workdir.path(), "ig_svg", payload).await;

    let err = outcome.expect_err("non-base64 data url should error");
    assert!(matches!(err, CodexErr::InvalidRequest(_)));
}
}
Loading
Loading